diff --git a/.env b/.env new file mode 100644 index 000000000..79f19d23a --- /dev/null +++ b/.env @@ -0,0 +1,66 @@ +# You only need to fill in the values when running the examples, see examples/ + +# For using GPT on OpenAI +OPENAI_API_KEY= + +# For using Claude on Anthropic +ANTHROPIC_API_KEY= + +# For using Mistral +MISTRAL_API_KEY= + +# For using Voyage +VOYAGE_API_KEY= + +# For using Replicate +REPLICATE_API_KEY= + +# For using Ollama +OLLAMA_HOST_URL= + +# For using GPT on Azure +AZURE_OPENAI_BASEURL= +AZURE_OPENAI_KEY= +AZURE_OPENAI_GPT_DEPLOYMENT= +AZURE_OPENAI_GPT_API_VERSION= +AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT= +AZURE_OPENAI_EMBEDDINGS_API_VERSION= +AZURE_OPENAI_WHISPER_DEPLOYMENT= +AZURE_OPENAI_WHISPER_API_VERSION= + +# For using Llama on Azure +AZURE_LLAMA_BASEURL= +AZURE_LLAMA_KEY= + +# For using Bedrock +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_DEFAULT_REGION= + +# Hugging Face Access Token +HUGGINGFACE_KEY= + +# For using OpenRouter +OPENROUTER_KEY= + +# For using SerpApi (tool) +SERP_API_KEY= + +# For using Tavily (tool) +TAVILY_API_KEY= + +# For using Brave (tool) +BRAVE_API_KEY= + +# For using MongoDB Atlas (store) +MONGODB_URI= + +# For using Pinecone (store) +PINECONE_API_KEY= +PINECONE_HOST= + +# Some examples are expensive to run, so we disable them by default +RUN_EXPENSIVE_EXAMPLES=false + +# For using Gemini +GOOGLE_API_KEY= diff --git a/.gitignore b/.gitignore index 3c8128ab5..1abc003df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .doctor-rst.cache .php-cs-fixer.cache .phpunit.result.cache +.env.local /composer.lock /vendor diff --git a/composer.json b/composer.json index 0024d3c17..7d404e74c 100644 --- a/composer.json +++ b/composer.json @@ -6,8 +6,25 @@ ], "require-dev": { "php": ">=8.1", + "symfony/ai-agent": "@dev", + "symfony/ai-platform": "@dev", + "symfony/console": "^6.4|^7.0", + "symfony/css-selector": "^6.4|^7.0", + "symfony/dom-crawler": "^6.4|^7.0", + "symfony/dotenv": "^6.4|^7.0", "symfony/filesystem": "^6.4|^7.0", "symfony/finder": "^6.4|^7.0", + "symfony/process": "^6.4|^7.0", + "symfony/var-dumper": "^6.4|^7.0", "php-cs-fixer/shim": "^3.75" + }, + "repositories": [ + { "type": "path", "url": "src/agent" }, + { "type": "path", "url": "src/platform" } + ], + "autoload-dev": { + "psr-4": { + "Symfony\\AI\\Fixtures\\": "fixtures/" + } } } diff --git a/example b/example new file mode 100755 index 000000000..bcb6a135b --- /dev/null +++ b/example @@ -0,0 +1,110 @@ +#!/usr/bin/env php +setDescription('Runs all Symfony AI examples in folder examples/') + ->addArgument('subdirectory', InputArgument::OPTIONAL, 'Subdirectory to run examples from, e.g. "anthropic" or "huggingface".') + ->setCode(function (InputInterface $input, OutputInterface $output) { + $io = new SymfonyStyle($input, $output); + $io->title('Symfony AI Examples'); + + $directory = __DIR__.'/examples'; + + if ($subdirectory = $input->getArgument('subdirectory')) { + $directory .= '/'.$subdirectory; + if (!is_dir($directory)) { + $io->error(sprintf('Subdirectory "%s" does not exist.', $subdirectory)); + return Command::FAILURE; + } + } + + $examples = (new Finder()) + ->in($directory) + ->name('*.php') + ->sortByName() + ->files(); + + /** @var array{example: SplFileInfo, process: Process} $exampleRuns */ + $exampleRuns = []; + foreach ($examples as $example) { + $exampleRuns[] = [ + 'example' => $example, + 'process' => $process = new Process(['php', $example->getRealPath()]), + ]; + $process->start(); + } + + $section = $output->section(); + $renderTable = function () use ($exampleRuns, $section) { + $section->clear(); + $table = new Table($section); + $table->setHeaders(['Example', 'State', 'Output']); + foreach ($exampleRuns as $run) { + /** @var SplFileInfo $example */ + /** @var Process $process */ + ['example' => $example, 'process' => $process] = $run; + + $output = str_replace(PHP_EOL, ' ', $process->getOutput()); + $output = strlen($output) <= 100 ? $output : substr($output, 0, 100).'...'; + $emptyOutput = 0 === strlen(trim($output)); + + $state = 'Running'; + if ($process->isTerminated()) { + $success = $process->isSuccessful() && !$emptyOutput; + $state = $success ? 'Finished' + : (1 === $run['process']->getExitCode() || $emptyOutput ? 'Failed' : 'Skipped'); + } + + $table->addRow([$example->getRelativePathname(), $state, $output]); + } + $table->render(); + }; + + $examplesRunning = fn () => array_reduce($exampleRuns, fn ($running, $example) => $running || $example['process']->isRunning(), false); + while ($examplesRunning()) { + $renderTable(); + sleep(1); + } + + $renderTable(); + $io->newLine(); + + $successCount = array_reduce($exampleRuns, function ($count, $example) { + if ($example['process']->isSuccessful() && strlen(trim($example['process']->getOutput())) > 0) { + return $count + 1; + } + return $count; + }, 0); + + $totalCount = count($exampleRuns); + + if ($successCount < $totalCount) { + $io->warning(sprintf('%d out of %d examples ran successfully.', $successCount, $totalCount)); + } else { + $io->success(sprintf('All %d examples ran successfully!', $totalCount)); + } + + foreach ($exampleRuns as $run) { + if (!$run['process']->isSuccessful()) { + $io->section('Error in ' . $run['example']->getRelativePathname()); + $io->text($run['process']->getOutput()); + $io->text($run['process']->getErrorOutput()); + } + } + + return Command::SUCCESS; + }) + ->run(); diff --git a/examples/anthropic/chat.php b/examples/anthropic/chat.php new file mode 100644 index 000000000..b33e34f2f --- /dev/null +++ b/examples/anthropic/chat.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(Claude::SONNET_37); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/anthropic/image-input-binary.php b/examples/anthropic/image-input-binary.php new file mode 100644 index 000000000..6df96ae64 --- /dev/null +++ b/examples/anthropic/image-input-binary.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(Claude::SONNET_37); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'), + 'Describe this image.', + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/anthropic/image-input-url.php b/examples/anthropic/image-input-url.php new file mode 100644 index 000000000..cceed03c4 --- /dev/null +++ b/examples/anthropic/image-input-url.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(Claude::SONNET_37); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + new ImageUrl('https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg'), + 'Describe this image.', + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/anthropic/pdf-input-binary.php b/examples/anthropic/pdf-input-binary.php new file mode 100644 index 000000000..f74beb025 --- /dev/null +++ b/examples/anthropic/pdf-input-binary.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(Claude::SONNET_37); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::ofUser( + Document::fromFile(dirname(__DIR__, 2).'/tests/Fixture/document.pdf'), + 'What is this document about?', + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/anthropic/pdf-input-url.php b/examples/anthropic/pdf-input-url.php new file mode 100644 index 000000000..cbe4211b2 --- /dev/null +++ b/examples/anthropic/pdf-input-url.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(Claude::SONNET_37); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::ofUser( + new DocumentUrl('https://upload.wikimedia.org/wikipedia/commons/2/20/Re_example.pdf'), + 'What is this document about?', + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/anthropic/stream.php b/examples/anthropic/stream.php new file mode 100644 index 000000000..757acfbf4 --- /dev/null +++ b/examples/anthropic/stream.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a thoughtful philosopher.'), + Message::ofUser('What is the purpose of an ant?'), +); +$response = $agent->call($messages, [ + 'stream' => true, // enable streaming of response text +]); + +foreach ($response->getContent() as $word) { + echo $word; +} +echo \PHP_EOL; diff --git a/examples/anthropic/toolcall.php b/examples/anthropic/toolcall.php new file mode 100644 index 000000000..16ba6616a --- /dev/null +++ b/examples/anthropic/toolcall.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['ANTHROPIC_API_KEY'])) { + echo 'Please set the ANTHROPIC_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['ANTHROPIC_API_KEY']); +$model = new Claude(); + +$wikipedia = new Wikipedia(HttpClient::create()); +$toolbox = Toolbox::create($wikipedia); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('Who is the current chancellor of Germany?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/azure/audio-transcript.php b/examples/azure/audio-transcript.php new file mode 100644 index 000000000..c25706f68 --- /dev/null +++ b/examples/azure/audio-transcript.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AZURE_OPENAI_BASEURL']) || empty($_ENV['AZURE_OPENAI_WHISPER_DEPLOYMENT']) || empty($_ENV['AZURE_OPENAI_WHISPER_API_VERSION']) || empty($_ENV['AZURE_OPENAI_KEY']) +) { + echo 'Please set the AZURE_OPENAI_BASEURL, AZURE_OPENAI_WHISPER_DEPLOYMENT, AZURE_OPENAI_WHISPER_API_VERSION, and AZURE_OPENAI_KEY environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create( + $_ENV['AZURE_OPENAI_BASEURL'], + $_ENV['AZURE_OPENAI_WHISPER_DEPLOYMENT'], + $_ENV['AZURE_OPENAI_WHISPER_API_VERSION'], + $_ENV['AZURE_OPENAI_KEY'], +); +$model = new Whisper(); +$file = Audio::fromFile(dirname(__DIR__, 2).'/tests/Fixture/audio.mp3'); + +$response = $platform->request($model, $file); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/azure/chat-gpt.php b/examples/azure/chat-gpt.php new file mode 100644 index 000000000..b38a8aaae --- /dev/null +++ b/examples/azure/chat-gpt.php @@ -0,0 +1,34 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AZURE_OPENAI_BASEURL']) || empty($_ENV['AZURE_OPENAI_GPT_DEPLOYMENT']) || empty($_ENV['AZURE_OPENAI_GPT_API_VERSION']) || empty($_ENV['AZURE_OPENAI_KEY']) +) { + echo 'Please set the AZURE_OPENAI_BASEURL, AZURE_OPENAI_GPT_DEPLOYMENT, AZURE_OPENAI_GPT_API_VERSION, and AZURE_OPENAI_KEY environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create( + $_ENV['AZURE_OPENAI_BASEURL'], + $_ENV['AZURE_OPENAI_GPT_DEPLOYMENT'], + $_ENV['AZURE_OPENAI_GPT_API_VERSION'], + $_ENV['AZURE_OPENAI_KEY'], +); +$model = new GPT(GPT::GPT_4O_MINI); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/azure/chat-llama.php b/examples/azure/chat-llama.php new file mode 100644 index 000000000..931c33c89 --- /dev/null +++ b/examples/azure/chat-llama.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AZURE_LLAMA_BASEURL']) || empty($_ENV['AZURE_LLAMA_KEY'])) { + echo 'Please set the AZURE_LLAMA_BASEURL and AZURE_LLAMA_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['AZURE_LLAMA_BASEURL'], $_ENV['AZURE_LLAMA_KEY']); +$model = new Llama(Llama::V3_3_70B_INSTRUCT); + +$agent = new Agent($platform, $model); +$messages = new MessageBag(Message::ofUser('I am going to Paris, what should I see?')); +$response = $agent->call($messages, [ + 'max_tokens' => 2048, + 'temperature' => 0.8, + 'top_p' => 0.1, + 'presence_penalty' => 0, + 'frequency_penalty' => 0, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/azure/embeddings.php b/examples/azure/embeddings.php new file mode 100644 index 000000000..407063256 --- /dev/null +++ b/examples/azure/embeddings.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AZURE_OPENAI_BASEURL']) || empty($_ENV['AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT']) || empty($_ENV['AZURE_OPENAI_EMBEDDINGS_API_VERSION']) || empty($_ENV['AZURE_OPENAI_KEY']) +) { + echo 'Please set the AZURE_OPENAI_BASEURL, AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT, AZURE_OPENAI_EMBEDDINGS_API_VERSION, and AZURE_OPENAI_KEY environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create( + $_ENV['AZURE_OPENAI_BASEURL'], + $_ENV['AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT'], + $_ENV['AZURE_OPENAI_EMBEDDINGS_API_VERSION'], + $_ENV['AZURE_OPENAI_KEY'], +); +$embeddings = new Embeddings(); + +$response = $platform->request($embeddings, <<getContent()[0]->getDimensions().\PHP_EOL; diff --git a/examples/bedrock/chat-claude.php b/examples/bedrock/chat-claude.php new file mode 100644 index 000000000..550ad978e --- /dev/null +++ b/examples/bedrock/chat-claude.php @@ -0,0 +1,29 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Claude(); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You answer questions in short and concise manner.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/bedrock/chat-llama.php b/examples/bedrock/chat-llama.php new file mode 100644 index 000000000..4cf1974f0 --- /dev/null +++ b/examples/bedrock/chat-llama.php @@ -0,0 +1,29 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Llama(Llama::V3_2_3B_INSTRUCT); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/bedrock/chat-nova.php b/examples/bedrock/chat-nova.php new file mode 100644 index 000000000..6ac21c5db --- /dev/null +++ b/examples/bedrock/chat-nova.php @@ -0,0 +1,29 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Nova(Nova::PRO); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/bedrock/image-claude-binary.php b/examples/bedrock/image-claude-binary.php new file mode 100644 index 000000000..91be24413 --- /dev/null +++ b/examples/bedrock/image-claude-binary.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Claude(); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + 'Describe the image as a comedian would do it.', + Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/bedrock/image-nova.php b/examples/bedrock/image-nova.php new file mode 100644 index 000000000..9245ae523 --- /dev/null +++ b/examples/bedrock/image-nova.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Nova(Nova::PRO); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + 'Describe the image as a comedian would do it.', + Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/bedrock/toolcall-claude.php b/examples/bedrock/toolcall-claude.php new file mode 100644 index 000000000..a6a6d7746 --- /dev/null +++ b/examples/bedrock/toolcall-claude.php @@ -0,0 +1,34 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Claude(); + +$wikipedia = new Wikipedia(HttpClient::create()); +$toolbox = Toolbox::create($wikipedia); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('Who is the current chancellor of Germany?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/bedrock/toolcall-nova.php b/examples/bedrock/toolcall-nova.php new file mode 100644 index 000000000..f00306a8b --- /dev/null +++ b/examples/bedrock/toolcall-nova.php @@ -0,0 +1,36 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['AWS_ACCESS_KEY_ID']) || empty($_ENV['AWS_SECRET_ACCESS_KEY']) || empty($_ENV['AWS_DEFAULT_REGION']) +) { + echo 'Please set the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION environment variables.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create(); +$model = new Nova(); + +$wikipedia = new Wikipedia(HttpClient::create()); +$toolbox = Toolbox::create($wikipedia); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag( + Message::ofUser('Who is the current chancellor of Germany? Use Wikipedia to find the answer.') +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/chat-system-prompt.php b/examples/chat-system-prompt.php new file mode 100644 index 000000000..1107f928d --- /dev/null +++ b/examples/chat-system-prompt.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$processor = new SystemPromptInputProcessor('You are Yoda and write like he speaks. But short.'); + +$agent = new Agent($platform, $model, [$processor]); +$messages = new MessageBag(Message::ofUser('What is the meaning of life?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/google/chat.php b/examples/google/chat.php new file mode 100644 index 000000000..c78eb5145 --- /dev/null +++ b/examples/google/chat.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['GOOGLE_API_KEY'])) { + echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']); +$model = new Gemini(Gemini::GEMINI_2_FLASH); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/google/image-input.php b/examples/google/image-input.php new file mode 100644 index 000000000..7e41e3a97 --- /dev/null +++ b/examples/google/image-input.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['GOOGLE_API_KEY'])) { + echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']); +$model = new Gemini(Gemini::GEMINI_1_5_FLASH); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + 'Describe the image as a comedian would do it.', + Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/google/stream.php b/examples/google/stream.php new file mode 100644 index 000000000..aa8caa3f9 --- /dev/null +++ b/examples/google/stream.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['GOOGLE_API_KEY'])) { + echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']); +$model = new Gemini(Gemini::GEMINI_2_FLASH); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a funny clown that entertains people.'), + Message::ofUser('What is the purpose of an ant?'), +); +$response = $agent->call($messages, [ + 'stream' => true, // enable streaming of response text +]); + +foreach ($response->getContent() as $word) { + echo $word; +} +echo \PHP_EOL; diff --git a/examples/huggingface/_model-listing.php b/examples/huggingface/_model-listing.php new file mode 100644 index 000000000..3b5d2aab2 --- /dev/null +++ b/examples/huggingface/_model-listing.php @@ -0,0 +1,39 @@ +setDescription('Lists all available models on HuggingFace') + ->addOption('provider', 'p', InputOption::VALUE_REQUIRED, 'Name of the inference provider to filter models by') + ->addOption('task', 't', InputOption::VALUE_REQUIRED, 'Name of the task to filter models by') + ->setCode(function (InputInterface $input, OutputInterface $output) { + $io = new SymfonyStyle($input, $output); + $io->title('HuggingFace Model Listing'); + + $provider = $input->getOption('provider'); + $task = $input->getOption('task'); + + $models = (new ApiClient())->models($provider, $task); + + if (0 === count($models)) { + $io->error('No models found for the given provider and task.'); + + return Command::FAILURE; + } + + $io->listing( + array_map(fn (Model $model) => $model->getName(), $models) + ); + + return Command::SUCCESS; + }) + ->run(); diff --git a/examples/huggingface/audio-classification.php b/examples/huggingface/audio-classification.php new file mode 100644 index 000000000..c691d76c1 --- /dev/null +++ b/examples/huggingface/audio-classification.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('MIT/ast-finetuned-audioset-10-10-0.4593'); +$audio = Audio::fromFile(dirname(__DIR__, 2).'/tests/Fixture/audio.mp3'); + +$response = $platform->request($model, $audio, [ + 'task' => Task::AUDIO_CLASSIFICATION, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/automatic-speech-recognition.php b/examples/huggingface/automatic-speech-recognition.php new file mode 100644 index 000000000..ab51e3226 --- /dev/null +++ b/examples/huggingface/automatic-speech-recognition.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('openai/whisper-large-v3'); +$audio = Audio::fromFile(dirname(__DIR__, 2).'/tests/Fixture/audio.mp3'); + +$response = $platform->request($model, $audio, [ + 'task' => Task::AUTOMATIC_SPEECH_RECOGNITION, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/huggingface/chat-completion.php b/examples/huggingface/chat-completion.php new file mode 100644 index 000000000..a324c1992 --- /dev/null +++ b/examples/huggingface/chat-completion.php @@ -0,0 +1,26 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('HuggingFaceH4/zephyr-7b-beta'); + +$messages = new MessageBag(Message::ofUser('Hello, how are you doing today?')); +$response = $platform->request($model, $messages, [ + 'task' => Task::CHAT_COMPLETION, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/huggingface/feature-extraction.php b/examples/huggingface/feature-extraction.php new file mode 100644 index 000000000..ac63149a7 --- /dev/null +++ b/examples/huggingface/feature-extraction.php @@ -0,0 +1,26 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('thenlper/gte-large'); + +$response = $platform->request($model, 'Today is a sunny day and I will get some ice cream.', [ + 'task' => Task::FEATURE_EXTRACTION, +]); + +assert($response instanceof VectorResponse); + +echo 'Dimensions: '.$response->getContent()[0]->getDimensions().\PHP_EOL; diff --git a/examples/huggingface/fill-mask.php b/examples/huggingface/fill-mask.php new file mode 100644 index 000000000..5eac88c82 --- /dev/null +++ b/examples/huggingface/fill-mask.php @@ -0,0 +1,23 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('FacebookAI/xlm-roberta-base'); + +$response = $platform->request($model, 'Hello I\'m a model.', [ + 'task' => Task::FILL_MASK, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/image-classification.php b/examples/huggingface/image-classification.php new file mode 100644 index 000000000..c9f5a956d --- /dev/null +++ b/examples/huggingface/image-classification.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('google/vit-base-patch16-224'); + +$image = Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'); +$response = $platform->request($model, $image, [ + 'task' => Task::IMAGE_CLASSIFICATION, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/image-segmentation.php b/examples/huggingface/image-segmentation.php new file mode 100644 index 000000000..2e6952222 --- /dev/null +++ b/examples/huggingface/image-segmentation.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('nvidia/segformer-b0-finetuned-ade-512-512'); + +$image = Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'); +$response = $platform->request($model, $image, [ + 'task' => Task::IMAGE_SEGMENTATION, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/image-to-text.php b/examples/huggingface/image-to-text.php new file mode 100644 index 000000000..af0ddb237 --- /dev/null +++ b/examples/huggingface/image-to-text.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('Salesforce/blip-image-captioning-base'); + +$image = Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'); +$response = $platform->request($model, $image, [ + 'task' => Task::IMAGE_TO_TEXT, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/huggingface/object-detection.php b/examples/huggingface/object-detection.php new file mode 100644 index 000000000..e02f0b24a --- /dev/null +++ b/examples/huggingface/object-detection.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('facebook/detr-resnet-50'); + +$image = Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'); +$response = $platform->request($model, $image, [ + 'task' => Task::OBJECT_DETECTION, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/question-answering.php b/examples/huggingface/question-answering.php new file mode 100644 index 000000000..78d79fbfa --- /dev/null +++ b/examples/huggingface/question-answering.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('deepset/roberta-base-squad2'); + +$input = [ + 'question' => 'What is the capital of France?', + 'context' => 'Paris is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018, in an area of more than 105 square kilometres.', +]; + +$response = $platform->request($model, $input, [ + 'task' => Task::QUESTION_ANSWERING, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/sentence-similarity.php b/examples/huggingface/sentence-similarity.php new file mode 100644 index 000000000..f67411719 --- /dev/null +++ b/examples/huggingface/sentence-similarity.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('sentence-transformers/all-MiniLM-L6-v2'); + +$input = [ + 'source_sentence' => 'That is a happy dog', + 'sentences' => [ + 'That is a happy canine', + 'That is a happy cat', + 'Today is a sunny day', + ], +]; + +$response = $platform->request($model, $input, [ + 'task' => Task::SENTENCE_SIMILARITY, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/summarization.php b/examples/huggingface/summarization.php new file mode 100644 index 000000000..3db60942e --- /dev/null +++ b/examples/huggingface/summarization.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('facebook/bart-large-cnn'); + +$longText = <<request($model, $longText, [ + 'task' => Task::SUMMARIZATION, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/huggingface/table-question-answering.php b/examples/huggingface/table-question-answering.php new file mode 100644 index 000000000..469f2af71 --- /dev/null +++ b/examples/huggingface/table-question-answering.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('microsoft/tapex-base'); + +$input = [ + 'query' => 'select year where city = beijing', + 'table' => [ + 'year' => [1896, 1900, 1904, 2004, 2008, 2012], + 'city' => ['athens', 'paris', 'st. louis', 'athens', 'beijing', 'london'], + ], +]; + +$response = $platform->request($model, $input, [ + 'task' => Task::TABLE_QUESTION_ANSWERING, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/text-classification.php b/examples/huggingface/text-classification.php new file mode 100644 index 000000000..6e59742cf --- /dev/null +++ b/examples/huggingface/text-classification.php @@ -0,0 +1,23 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('ProsusAI/finbert'); + +$response = $platform->request($model, 'I like you. I love you.', [ + 'task' => Task::TEXT_CLASSIFICATION, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/text-generation.php b/examples/huggingface/text-generation.php new file mode 100644 index 000000000..210d3efc8 --- /dev/null +++ b/examples/huggingface/text-generation.php @@ -0,0 +1,23 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('gpt2'); + +$response = $platform->request($model, 'The quick brown fox jumps over the lazy', [ + 'task' => Task::TEXT_GENERATION, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/huggingface/text-to-image.php b/examples/huggingface/text-to-image.php new file mode 100644 index 000000000..f17960026 --- /dev/null +++ b/examples/huggingface/text-to-image.php @@ -0,0 +1,26 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('black-forest-labs/FLUX.1-dev'); + +$response = $platform->request($model, 'Astronaut riding a horse', [ + 'task' => Task::TEXT_TO_IMAGE, +]); + +assert($response instanceof BinaryResponse); + +echo $response->toBase64().\PHP_EOL; diff --git a/examples/huggingface/token-classification.php b/examples/huggingface/token-classification.php new file mode 100644 index 000000000..3aaf8c892 --- /dev/null +++ b/examples/huggingface/token-classification.php @@ -0,0 +1,23 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('dbmdz/bert-large-cased-finetuned-conll03-english'); + +$response = $platform->request($model, 'John Smith works at Microsoft in London.', [ + 'task' => Task::TOKEN_CLASSIFICATION, +]); + +dump($response->getContent()); diff --git a/examples/huggingface/translation.php b/examples/huggingface/translation.php new file mode 100644 index 000000000..a4edfae33 --- /dev/null +++ b/examples/huggingface/translation.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('facebook/mbart-large-50-many-to-many-mmt'); + +$response = $platform->request($model, 'Меня зовут Вольфганг и я живу в Берлине', [ + 'task' => Task::TRANSLATION, + 'src_lang' => 'ru', + 'tgt_lang' => 'en', +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/huggingface/zero-shot-classification.php b/examples/huggingface/zero-shot-classification.php new file mode 100644 index 000000000..ee0069413 --- /dev/null +++ b/examples/huggingface/zero-shot-classification.php @@ -0,0 +1,25 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['HUGGINGFACE_KEY'])) { + echo 'Please set the HUGGINGFACE_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']); +$model = new Model('facebook/bart-large-mnli'); + +$text = 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!'; +$response = $platform->request($model, $text, [ + 'task' => Task::ZERO_SHOT_CLASSIFICATION, + 'candidate_labels' => ['refund', 'legal', 'faq'], +]); + +dump($response->getContent()); diff --git a/examples/mistral/chat.php b/examples/mistral/chat.php new file mode 100644 index 000000000..406caad29 --- /dev/null +++ b/examples/mistral/chat.php @@ -0,0 +1,27 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['MISTRAL_API_KEY'])) { + echo 'Please set the REPLICATE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Mistral(); +$agent = new Agent($platform, $model); + +$messages = new MessageBag(Message::ofUser('What is the best French cheese?')); +$response = $agent->call($messages, [ + 'temperature' => 0.7, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/mistral/embeddings.php b/examples/mistral/embeddings.php new file mode 100644 index 000000000..b0ff9239f --- /dev/null +++ b/examples/mistral/embeddings.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['MISTRAL_API_KEY'])) { + echo 'Please set the MISTRAL_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Embeddings(); + +$response = $platform->request($model, <<getContent()[0]->getDimensions().\PHP_EOL; diff --git a/examples/mistral/image.php b/examples/mistral/image.php new file mode 100644 index 000000000..ce369f578 --- /dev/null +++ b/examples/mistral/image.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Mistral(Mistral::MISTRAL_SMALL); +$agent = new Agent($platform, $model); + +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + 'Describe the image as a comedian would do it.', + Image::fromFile(dirname(__DIR__, 2).'/tests/Fixture/image.jpg'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/mistral/stream.php b/examples/mistral/stream.php new file mode 100644 index 000000000..12a81d7f3 --- /dev/null +++ b/examples/mistral/stream.php @@ -0,0 +1,30 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['MISTRAL_API_KEY'])) { + echo 'Please set the REPLICATE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Mistral(); +$agent = new Agent($platform, $model); + +$messages = new MessageBag(Message::ofUser('What is the eighth prime number?')); +$response = $agent->call($messages, [ + 'stream' => true, +]); + +foreach ($response->getContent() as $word) { + echo $word; +} +echo \PHP_EOL; diff --git a/examples/mistral/structured-output-math.php b/examples/mistral/structured-output-math.php new file mode 100644 index 000000000..358c68221 --- /dev/null +++ b/examples/mistral/structured-output-math.php @@ -0,0 +1,36 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['MISTRAL_API_KEY'])) { + echo 'Please set the MISTRAL_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Mistral(Mistral::MISTRAL_SMALL); +$serializer = new Serializer([new ObjectNormalizer()], [new JsonEncoder()]); + +$processor = new AgentProcessor(new ResponseFormatFactory(), $serializer); +$agent = new Agent($platform, $model, [$processor], [$processor]); +$messages = new MessageBag( + Message::forSystem('You are a helpful math tutor. Guide the user through the solution step by step.'), + Message::ofUser('how can I solve 8x + 7 = -23'), +); +$response = $agent->call($messages, ['output_structure' => MathReasoning::class]); + +dump($response->getContent()); diff --git a/examples/mistral/toolcall-stream.php b/examples/mistral/toolcall-stream.php new file mode 100644 index 000000000..9cd2ea544 --- /dev/null +++ b/examples/mistral/toolcall-stream.php @@ -0,0 +1,38 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['MISTRAL_API_KEY'])) { + echo 'Please set the REPLICATE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Mistral(); + +$transcriber = new YouTubeTranscriber(HttpClient::create()); +$toolbox = Toolbox::create($transcriber); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('Please summarize this video for me: https://www.youtube.com/watch?v=6uXW-ulpj0s')); +$response = $agent->call($messages, [ + 'stream' => true, +]); + +foreach ($response->getContent() as $word) { + echo $word; +} +echo \PHP_EOL; diff --git a/examples/mistral/toolcall.php b/examples/mistral/toolcall.php new file mode 100644 index 000000000..3c9e5b9ff --- /dev/null +++ b/examples/mistral/toolcall.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['MISTRAL_API_KEY'])) { + echo 'Please set the REPLICATE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['MISTRAL_API_KEY']); +$model = new Mistral(); + +$toolbox = Toolbox::create(new Clock()); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('What time is it?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/ollama/chat-llama.php b/examples/ollama/chat-llama.php new file mode 100644 index 000000000..e3dc801c5 --- /dev/null +++ b/examples/ollama/chat-llama.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OLLAMA_HOST_URL'])) { + echo 'Please set the OLLAMA_HOST_URL environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OLLAMA_HOST_URL']); +$model = new Llama('llama3.2'); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/audio-input.php b/examples/openai/audio-input.php new file mode 100644 index 000000000..d7c4b237e --- /dev/null +++ b/examples/openai/audio-input.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_AUDIO); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::ofUser( + 'What is this recording about?', + Audio::fromFile(dirname(__DIR__, 2).'/fixtures/audio.mp3'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/audio-transcript.php b/examples/openai/audio-transcript.php new file mode 100644 index 000000000..20cdacab3 --- /dev/null +++ b/examples/openai/audio-transcript.php @@ -0,0 +1,22 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new Whisper(); +$file = Audio::fromFile(dirname(__DIR__, 2).'/fixtures/audio.mp3'); + +$response = $platform->request($model, $file); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/chat-o1.php b/examples/openai/chat-o1.php new file mode 100644 index 000000000..5f7eb7bf0 --- /dev/null +++ b/examples/openai/chat-o1.php @@ -0,0 +1,37 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +if (empty($_ENV['RUN_EXPENSIVE_EXAMPLES']) || false === filter_var($_ENV['RUN_EXPENSIVE_EXAMPLES'], \FILTER_VALIDATE_BOOLEAN)) { + echo 'This example is marked as expensive and will not run unless RUN_EXPENSIVE_EXAMPLES is set to true.'.\PHP_EOL; + exit(134); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::O1_PREVIEW); + +$prompt = <<call(new MessageBag(Message::ofUser($prompt))); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/chat.php b/examples/openai/chat.php new file mode 100644 index 000000000..9d59648ed --- /dev/null +++ b/examples/openai/chat.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI, [ + 'temperature' => 0.5, // default options for the model +]); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages, [ + 'max_tokens' => 500, // specific options just for this call +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/embeddings.php b/examples/openai/embeddings.php new file mode 100644 index 000000000..a2418b8be --- /dev/null +++ b/examples/openai/embeddings.php @@ -0,0 +1,27 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$embeddings = new Embeddings(); + +$response = $platform->request($embeddings, <<getContent()[0]->getDimensions().\PHP_EOL; diff --git a/examples/openai/image-input-binary.php b/examples/openai/image-input-binary.php new file mode 100644 index 000000000..e7d7c704f --- /dev/null +++ b/examples/openai/image-input-binary.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + 'Describe the image as a comedian would do it.', + Image::fromFile(dirname(__DIR__, 2).'/fixtures/image.jpg'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/image-input-url.php b/examples/openai/image-input-url.php new file mode 100644 index 000000000..db0054dc3 --- /dev/null +++ b/examples/openai/image-input-url.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser( + 'Describe the image as a comedian would do it.', + new ImageUrl('https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Webysther_20160423_-_Elephpant.svg/350px-Webysther_20160423_-_Elephpant.svg.png'), + ), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openai/image-output-dall-e-2.php b/examples/openai/image-output-dall-e-2.php new file mode 100644 index 000000000..62e93ea63 --- /dev/null +++ b/examples/openai/image-output-dall-e-2.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); + +$response = $platform->request( + model: new DallE(), // Utilize Dall-E 2 version in default + input: 'A cartoon-style elephant with a long trunk and large ears.', + options: [ + 'response_format' => 'url', // Generate response as URL + 'n' => 2, // Generate multiple images for example + ], +); + +foreach ($response->getContent() as $index => $image) { + echo 'Image '.$index.': '.$image->url.\PHP_EOL; +} diff --git a/examples/openai/image-output-dall-e-3.php b/examples/openai/image-output-dall-e-3.php new file mode 100644 index 000000000..7bb60110a --- /dev/null +++ b/examples/openai/image-output-dall-e-3.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); + +$response = $platform->request( + model: new DallE(name: DallE::DALL_E_3), + input: 'A cartoon-style elephant with a long trunk and large ears.', + options: [ + 'response_format' => 'url', // Generate response as URL + ], +); + +assert($response instanceof ImageResponse); + +echo 'Revised Prompt: '.$response->revisedPrompt.\PHP_EOL.\PHP_EOL; + +foreach ($response->getContent() as $index => $image) { + echo 'Image '.$index.': '.$image->url.\PHP_EOL; +} diff --git a/examples/openai/stream.php b/examples/openai/stream.php new file mode 100644 index 000000000..de7713e95 --- /dev/null +++ b/examples/openai/stream.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a thoughtful philosopher.'), + Message::ofUser('What is the purpose of an ant?'), +); +$response = $agent->call($messages, [ + 'stream' => true, // enable streaming of response text +]); + +foreach ($response->getContent() as $word) { + echo $word; +} +echo \PHP_EOL; diff --git a/examples/openai/structured-output-clock.php b/examples/openai/structured-output-clock.php new file mode 100644 index 000000000..99c44ebd4 --- /dev/null +++ b/examples/openai/structured-output-clock.php @@ -0,0 +1,50 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$clock = new Clock(new SymfonyClock()); +$toolbox = Toolbox::create($clock); +$toolProcessor = new ToolProcessor($toolbox); +$structuredOutputProcessor = new StructuredOutputProcessor(); +$agent = new Agent($platform, $model, [$toolProcessor, $structuredOutputProcessor], [$toolProcessor, $structuredOutputProcessor]); + +$messages = new MessageBag(Message::ofUser('What date and time is it?')); +$response = $agent->call($messages, ['response_format' => [ + 'type' => 'json_schema', + 'json_schema' => [ + 'name' => 'clock', + 'strict' => true, + 'schema' => [ + 'type' => 'object', + 'properties' => [ + 'date' => ['type' => 'string', 'description' => 'The current date in the format YYYY-MM-DD.'], + 'time' => ['type' => 'string', 'description' => 'The current time in the format HH:MM:SS.'], + ], + 'required' => ['date', 'time'], + 'additionalProperties' => false, + ], + ], +]]); + +dump($response->getContent()); diff --git a/examples/openai/structured-output-math.php b/examples/openai/structured-output-math.php new file mode 100644 index 000000000..435754852 --- /dev/null +++ b/examples/openai/structured-output-math.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$processor = new AgentProcessor(); +$agent = new Agent($platform, $model, [$processor], [$processor]); +$messages = new MessageBag( + Message::forSystem('You are a helpful math tutor. Guide the user through the solution step by step.'), + Message::ofUser('how can I solve 8x + 7 = -23'), +); +$response = $agent->call($messages, ['output_structure' => MathReasoning::class]); + +dump($response->getContent()); diff --git a/examples/openai/token-metadata.php b/examples/openai/token-metadata.php new file mode 100644 index 000000000..5bbdf6123 --- /dev/null +++ b/examples/openai/token-metadata.php @@ -0,0 +1,38 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI, [ + 'temperature' => 0.5, // default options for the model +]); + +$agent = new Agent($platform, $model, outputProcessors: [new TokenOutputProcessor()]); +$messages = new MessageBag( + Message::forSystem('You are a pirate and you write funny.'), + Message::ofUser('What is the Symfony framework?'), +); +$response = $agent->call($messages, [ + 'max_tokens' => 500, // specific options just for this call +]); + +$metadata = $response->getMetadata(); + +echo 'Utilized Tokens: '.$metadata['total_tokens'].\PHP_EOL; +echo '-- Prompt Tokens: '.$metadata['prompt_tokens'].\PHP_EOL; +echo '-- Completion Tokens: '.$metadata['completion_tokens'].\PHP_EOL; +echo 'Remaining Tokens: '.$metadata['remaining_tokens'].\PHP_EOL; diff --git a/examples/openai/toolcall-stream.php b/examples/openai/toolcall-stream.php new file mode 100644 index 000000000..4ddba513d --- /dev/null +++ b/examples/openai/toolcall-stream.php @@ -0,0 +1,42 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$wikipedia = new Wikipedia(HttpClient::create()); +$toolbox = Toolbox::create($wikipedia); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); +$messages = new MessageBag(Message::ofUser(<<call($messages, [ + 'stream' => true, // enable streaming of response text +]); + +foreach ($response->getContent() as $word) { + echo $word; +} + +echo \PHP_EOL; diff --git a/examples/openai/toolcall.php b/examples/openai/toolcall.php new file mode 100644 index 000000000..b42eaa731 --- /dev/null +++ b/examples/openai/toolcall.php @@ -0,0 +1,33 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$transcriber = new YouTubeTranscriber(HttpClient::create()); +$toolbox = Toolbox::create($transcriber); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('Please summarize this video for me: https://www.youtube.com/watch?v=6uXW-ulpj0s')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/openrouter/chat-gemini.php b/examples/openrouter/chat-gemini.php new file mode 100644 index 000000000..2ef571daf --- /dev/null +++ b/examples/openrouter/chat-gemini.php @@ -0,0 +1,30 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENROUTER_KEY'])) { + echo 'Please set the OPENROUTER_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENROUTER_KEY']); +// In case free is running into 429 rate limit errors, you can use the paid model: +// $model = new Model('google/gemini-2.0-flash-lite-001'); +$model = new Model('google/gemini-2.0-flash-exp:free'); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/parallel-chat-gpt.php b/examples/parallel-chat-gpt.php new file mode 100644 index 000000000..1292f6b0d --- /dev/null +++ b/examples/parallel-chat-gpt.php @@ -0,0 +1,36 @@ +loadEnv(dirname(__DIR__).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI, [ + 'temperature' => 0.5, // default options for the model +]); + +$messages = new MessageBag( + Message::forSystem('You will be given a letter and you answer with only the next letter of the alphabet.'), +); + +echo 'Initiating parallel calls to GPT on platform ...'.\PHP_EOL; +$responses = []; +foreach (range('A', 'D') as $letter) { + echo ' - Request for the letter '.$letter.' initiated.'.\PHP_EOL; + $responses[] = $platform->request($model, $messages->with(Message::ofUser($letter))); +} + +echo 'Waiting for the responses ...'.\PHP_EOL; +foreach ($responses as $response) { + echo 'Next Letter: '.$response->getContent().\PHP_EOL; +} diff --git a/examples/parallel-embeddings.php b/examples/parallel-embeddings.php new file mode 100644 index 000000000..54f6d1c6a --- /dev/null +++ b/examples/parallel-embeddings.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$ada = new Embeddings(Embeddings::TEXT_ADA_002); +$small = new Embeddings(Embeddings::TEXT_3_SMALL); +$large = new Embeddings(Embeddings::TEXT_3_LARGE); + +echo 'Initiating parallel embeddings calls to platform ...'.\PHP_EOL; +$responses = []; +foreach (['ADA' => $ada, 'Small' => $small, 'Large' => $large] as $name => $model) { + echo ' - Request for model '.$name.' initiated.'.\PHP_EOL; + $responses[] = $platform->request($model, 'Hello, world!'); +} + +echo 'Waiting for the responses ...'.\PHP_EOL; +foreach ($responses as $response) { + assert($response instanceof VectorResponse); + echo 'Dimensions: '.$response->getContent()[0]->getDimensions().\PHP_EOL; +} diff --git a/examples/replicate/chat-llama.php b/examples/replicate/chat-llama.php new file mode 100644 index 000000000..b74a494aa --- /dev/null +++ b/examples/replicate/chat-llama.php @@ -0,0 +1,28 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['REPLICATE_API_KEY'])) { + echo 'Please set the REPLICATE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['REPLICATE_API_KEY']); +$model = new Llama(); + +$agent = new Agent($platform, $model); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'), +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/store/mongodb-similarity-search.php b/examples/store/mongodb-similarity-search.php new file mode 100644 index 000000000..44c1b2b87 --- /dev/null +++ b/examples/store/mongodb-similarity-search.php @@ -0,0 +1,74 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY']) || empty($_ENV['MONGODB_URI'])) { + echo 'Please set OPENAI_API_KEY and MONGODB_URI environment variables.'.\PHP_EOL; + exit(1); +} + +// initialize the store +$store = new Store( + client: new MongoDBClient($_ENV['MONGODB_URI']), + databaseName: 'my-database', + collectionName: 'my-collection', + indexName: 'my-index', + vectorFieldName: 'vector', +); + +// our data +$movies = [ + ['title' => 'Inception', 'description' => 'A skilled thief is given a chance at redemption if he can successfully perform inception, the act of planting an idea in someone\'s subconscious.', 'director' => 'Christopher Nolan'], + ['title' => 'The Matrix', 'description' => 'A hacker discovers the world he lives in is a simulated reality and joins a rebellion to overthrow its controllers.', 'director' => 'The Wachowskis'], + ['title' => 'The Godfather', 'description' => 'The aging patriarch of an organized crime dynasty transfers control of his empire to his reluctant son.', 'director' => 'Francis Ford Coppola'], +]; + +// create embeddings and documents +foreach ($movies as $movie) { + $documents[] = new TextDocument( + id: Uuid::v4(), + content: 'Title: '.$movie['title'].\PHP_EOL.'Director: '.$movie['director'].\PHP_EOL.'Description: '.$movie['description'], + metadata: new Metadata($movie), + ); +} + +// create embeddings for documents +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$embedder = new Embedder($platform, $embeddings = new Embeddings(), $store); +$embedder->embed($documents); + +// initialize the index +$store->initialize(); + +$model = new GPT(GPT::GPT_4O_MINI); + +$similaritySearch = new SimilaritySearch($platform, $embeddings, $store); +$toolbox = Toolbox::create($similaritySearch); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag( + Message::forSystem('Please answer all user questions only using SimilaritySearch function.'), + Message::ofUser('Which movie fits the theme of the mafia?') +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/store/pinecone-similarity-search.php b/examples/store/pinecone-similarity-search.php new file mode 100644 index 000000000..06c335471 --- /dev/null +++ b/examples/store/pinecone-similarity-search.php @@ -0,0 +1,65 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY']) || empty($_ENV['PINECONE_API_KEY']) || empty($_ENV['PINECONE_HOST'])) { + echo 'Please set OPENAI_API_KEY, PINECONE_API_KEY and PINECONE_HOST environment variables.'.\PHP_EOL; + exit(1); +} + +// initialize the store +$store = new Store(Pinecone::client($_ENV['PINECONE_API_KEY'], $_ENV['PINECONE_HOST'])); + +// our data +$movies = [ + ['title' => 'Inception', 'description' => 'A skilled thief is given a chance at redemption if he can successfully perform inception, the act of planting an idea in someone\'s subconscious.', 'director' => 'Christopher Nolan'], + ['title' => 'The Matrix', 'description' => 'A hacker discovers the world he lives in is a simulated reality and joins a rebellion to overthrow its controllers.', 'director' => 'The Wachowskis'], + ['title' => 'The Godfather', 'description' => 'The aging patriarch of an organized crime dynasty transfers control of his empire to his reluctant son.', 'director' => 'Francis Ford Coppola'], +]; + +// create embeddings and documents +foreach ($movies as $movie) { + $documents[] = new TextDocument( + id: Uuid::v4(), + content: 'Title: '.$movie['title'].\PHP_EOL.'Director: '.$movie['director'].\PHP_EOL.'Description: '.$movie['description'], + metadata: new Metadata($movie), + ); +} + +// create embeddings for documents +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$embedder = new Embedder($platform, $embeddings = new Embeddings(), $store); +$embedder->embed($documents); + +$model = new GPT(GPT::GPT_4O_MINI); + +$similaritySearch = new SimilaritySearch($platform, $embeddings, $store); +$toolbox = Toolbox::create($similaritySearch); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag( + Message::forSystem('Please answer all user questions only using SimilaritySearch function.'), + Message::ofUser('Which movie fits the theme of the mafia?') +); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/toolbox/brave.php b/examples/toolbox/brave.php new file mode 100644 index 000000000..2ec0cff5a --- /dev/null +++ b/examples/toolbox/brave.php @@ -0,0 +1,35 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY']) || empty($_ENV['BRAVE_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY and BRAVE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$httpClient = HttpClient::create(); +$brave = new Brave($httpClient, $_ENV['BRAVE_API_KEY']); +$crawler = new Crawler($httpClient); +$toolbox = Toolbox::create($brave, $crawler); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('What was the latest game result of Dallas Cowboys?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/toolbox/clock.php b/examples/toolbox/clock.php new file mode 100644 index 000000000..262590a73 --- /dev/null +++ b/examples/toolbox/clock.php @@ -0,0 +1,34 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$metadataFactory = (new MemoryToolFactory()) + ->addTool(Clock::class, 'clock', 'Get the current date and time', 'now'); +$toolbox = new Toolbox($metadataFactory, [new Clock()]); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('What date and time is it?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/toolbox/serpapi.php b/examples/toolbox/serpapi.php new file mode 100644 index 000000000..178e1c165 --- /dev/null +++ b/examples/toolbox/serpapi.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY']) || empty($_ENV['SERP_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY and SERP_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$serpApi = new SerpApi(HttpClient::create(), $_ENV['SERP_API_KEY']); +$toolbox = Toolbox::create($serpApi); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('Who is the current chancellor of Germany?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/toolbox/tavily.php b/examples/toolbox/tavily.php new file mode 100644 index 000000000..ca52d17ce --- /dev/null +++ b/examples/toolbox/tavily.php @@ -0,0 +1,32 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY']) || empty($_ENV['TAVILY_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY and TAVILY_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$tavily = new Tavily(HttpClient::create(), $_ENV['TAVILY_API_KEY']); +$toolbox = Toolbox::create($tavily); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('What was the latest game result of Dallas Cowboys?')); +$response = $agent->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/toolbox/weather-event.php b/examples/toolbox/weather-event.php new file mode 100644 index 000000000..6202df26f --- /dev/null +++ b/examples/toolbox/weather-event.php @@ -0,0 +1,46 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['OPENAI_API_KEY'])) { + echo 'Please set the OPENAI_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); +$model = new GPT(GPT::GPT_4O_MINI); + +$openMeteo = new OpenMeteo(HttpClient::create()); +$toolbox = Toolbox::create($openMeteo); +$eventDispatcher = new EventDispatcher(); +$processor = new AgentProcessor($toolbox, eventDispatcher: $eventDispatcher); +$agent = new Agent($platform, $model, [$processor], [$processor]); + +// Add tool call result listener to enforce chain exits direct with structured response for weather tools +$eventDispatcher->addListener(ToolCallsExecuted::class, function (ToolCallsExecuted $event): void { + foreach ($event->toolCallResults as $toolCallResult) { + if (str_starts_with($toolCallResult->toolCall->name, 'weather_')) { + $event->response = new ObjectResponse($toolCallResult->result); + } + } +}); + +$messages = new MessageBag(Message::ofUser('How is the weather currently in Berlin?')); +$response = $agent->call($messages); + +dump($response->getContent()); diff --git a/examples/transformers/text-generation.php b/examples/transformers/text-generation.php new file mode 100644 index 000000000..a7cce6e30 --- /dev/null +++ b/examples/transformers/text-generation.php @@ -0,0 +1,26 @@ +request($model, 'How many continents are there in the world?', [ + 'task' => Task::Text2TextGeneration, +]); + +echo $response->getContent().\PHP_EOL; diff --git a/examples/voyage/embeddings.php b/examples/voyage/embeddings.php new file mode 100644 index 000000000..f2c9be7e4 --- /dev/null +++ b/examples/voyage/embeddings.php @@ -0,0 +1,27 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['VOYAGE_API_KEY'])) { + echo 'Please set the VOYAGE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['VOYAGE_API_KEY']); +$embeddings = new Voyage(); + +$response = $platform->request($embeddings, <<getContent()[0]->getDimensions().\PHP_EOL; diff --git a/fixtures/SomeStructure.php b/fixtures/SomeStructure.php new file mode 100644 index 000000000..e04ad2932 --- /dev/null +++ b/fixtures/SomeStructure.php @@ -0,0 +1,17 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixture; + +final class SomeStructure +{ + public string $some; +} diff --git a/fixtures/StructuredOutput/MathReasoning.php b/fixtures/StructuredOutput/MathReasoning.php new file mode 100644 index 000000000..9d49f9842 --- /dev/null +++ b/fixtures/StructuredOutput/MathReasoning.php @@ -0,0 +1,24 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\StructuredOutput; + +final class MathReasoning +{ + /** + * @param Step[] $steps + */ + public function __construct( + public array $steps, + public string $finalAnswer, + ) { + } +} diff --git a/fixtures/StructuredOutput/Step.php b/fixtures/StructuredOutput/Step.php new file mode 100644 index 000000000..8ef356149 --- /dev/null +++ b/fixtures/StructuredOutput/Step.php @@ -0,0 +1,21 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\StructuredOutput; + +final class Step +{ + public function __construct( + public string $explanation, + public string $output, + ) { + } +} diff --git a/fixtures/StructuredOutput/User.php b/fixtures/StructuredOutput/User.php new file mode 100644 index 000000000..f006b936c --- /dev/null +++ b/fixtures/StructuredOutput/User.php @@ -0,0 +1,24 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\StructuredOutput; + +final class User +{ + public int $id; + /** + * @var string The name of the user in lowercase + */ + public string $name; + public \DateTimeInterface $createdAt; + public bool $isActive; + public ?int $age = null; +} diff --git a/fixtures/StructuredOutput/UserWithConstructor.php b/fixtures/StructuredOutput/UserWithConstructor.php new file mode 100644 index 000000000..ceb7fa65d --- /dev/null +++ b/fixtures/StructuredOutput/UserWithConstructor.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\StructuredOutput; + +final class UserWithConstructor +{ + /** + * @param string $name The name of the user in lowercase + */ + public function __construct( + public int $id, + public string $name, + public \DateTimeInterface $createdAt, + public bool $isActive, + public ?int $age = null, + ) { + } +} diff --git a/fixtures/Tool/ToolArray.php b/fixtures/Tool/ToolArray.php new file mode 100644 index 000000000..23ff813e1 --- /dev/null +++ b/fixtures/Tool/ToolArray.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_no_params', 'A tool without parameters')] +final class ToolArray +{ + /** + * @param string[] $urls + * @param list $ids + */ + public function __invoke(array $urls, array $ids): string + { + return 'Hello world!'; + } +} diff --git a/fixtures/Tool/ToolException.php b/fixtures/Tool/ToolException.php new file mode 100644 index 000000000..f53aef7bc --- /dev/null +++ b/fixtures/Tool/ToolException.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_exception', description: 'This tool is broken', method: 'bar')] +final class ToolException +{ + public function bar(): string + { + throw new \Exception('Tool error.'); + } +} diff --git a/fixtures/Tool/ToolMisconfigured.php b/fixtures/Tool/ToolMisconfigured.php new file mode 100644 index 000000000..25abd67ac --- /dev/null +++ b/fixtures/Tool/ToolMisconfigured.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_misconfigured', description: 'This tool is misconfigured, see method', method: 'foo')] +final class ToolMisconfigured +{ + public function bar(): string + { + return 'Wrong Config Attribute'; + } +} diff --git a/fixtures/Tool/ToolMultiple.php b/fixtures/Tool/ToolMultiple.php new file mode 100644 index 000000000..861b33ba2 --- /dev/null +++ b/fixtures/Tool/ToolMultiple.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_hello_world', 'Function to say hello', method: 'hello')] +#[AsTool('tool_required_params', 'Function to say a number', method: 'bar')] +final class ToolMultiple +{ + /** + * @param string $world The world to say hello to + */ + public function hello(string $world): string + { + return \sprintf('Hello "%s".', $world); + } + + /** + * @param string $text The text given to the tool + * @param int $number A number given to the tool + */ + public function bar(string $text, int $number): string + { + return \sprintf('%s says "%d".', $text, $number); + } +} diff --git a/fixtures/Tool/ToolNoAttribute1.php b/fixtures/Tool/ToolNoAttribute1.php new file mode 100644 index 000000000..c3fbc64a6 --- /dev/null +++ b/fixtures/Tool/ToolNoAttribute1.php @@ -0,0 +1,24 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +final class ToolNoAttribute1 +{ + /** + * @param string $name the name of the person + * @param int $years the age of the person + */ + public function __invoke(string $name, int $years): string + { + return \sprintf('Happy Birthday, %s! You are %d years old.', $name, $years); + } +} diff --git a/fixtures/Tool/ToolNoAttribute2.php b/fixtures/Tool/ToolNoAttribute2.php new file mode 100644 index 000000000..28345652d --- /dev/null +++ b/fixtures/Tool/ToolNoAttribute2.php @@ -0,0 +1,32 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +final class ToolNoAttribute2 +{ + /** + * @param int $id the ID of the product + * @param int $amount the number of products + */ + public function buy(int $id, int $amount): string + { + return \sprintf('You bought %d of product %d.', $amount, $id); + } + + /** + * @param string $orderId the ID of the order + */ + public function cancel(string $orderId): string + { + return \sprintf('You canceled order %s.', $orderId); + } +} diff --git a/fixtures/Tool/ToolNoParams.php b/fixtures/Tool/ToolNoParams.php new file mode 100644 index 000000000..1d7e2586d --- /dev/null +++ b/fixtures/Tool/ToolNoParams.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_no_params', 'A tool without parameters')] +final class ToolNoParams +{ + public function __invoke(): string + { + return 'Hello world!'; + } +} diff --git a/fixtures/Tool/ToolOptionalParam.php b/fixtures/Tool/ToolOptionalParam.php new file mode 100644 index 000000000..281c5f4e0 --- /dev/null +++ b/fixtures/Tool/ToolOptionalParam.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_optional_param', 'A tool with one optional parameter', method: 'bar')] +final class ToolOptionalParam +{ + /** + * @param string $text The text given to the tool + * @param int $number A number given to the tool + */ + public function bar(string $text, int $number = 3): string + { + return \sprintf('%s says "%d".', $text, $number); + } +} diff --git a/fixtures/Tool/ToolRequiredParams.php b/fixtures/Tool/ToolRequiredParams.php new file mode 100644 index 000000000..3b56204fb --- /dev/null +++ b/fixtures/Tool/ToolRequiredParams.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_required_params', 'A tool with required parameters', method: 'bar')] +final class ToolRequiredParams +{ + /** + * @param string $text The text given to the tool + * @param int $number A number given to the tool + */ + public function bar(string $text, int $number): string + { + return \sprintf('%s says "%d".', $text, $number); + } +} diff --git a/fixtures/Tool/ToolWithToolParameterAttribute.php b/fixtures/Tool/ToolWithToolParameterAttribute.php new file mode 100644 index 000000000..6c6803980 --- /dev/null +++ b/fixtures/Tool/ToolWithToolParameterAttribute.php @@ -0,0 +1,71 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With; + +#[AsTool('tool_with_ToolParameter_attribute', 'A tool which has a parameter with described with #[ToolParameter] attribute')] +final class ToolWithToolParameterAttribute +{ + /** + * @param string $animal The animal given to the tool + * @param int $numberOfArticles The number of articles given to the tool + * @param string $infoEmail The info email given to the tool + * @param string $locales The locales given to the tool + * @param string $text The text given to the tool + * @param int $number The number given to the tool + * @param array $products The products given to the tool + * @param string $shippingAddress The shipping address given to the tool + */ + public function __invoke( + #[With(enum: ['dog', 'cat', 'bird'])] + string $animal, + #[With(const: 42)] + int $numberOfArticles, + #[With(const: 'info@example.de')] + string $infoEmail, + #[With(const: ['de', 'en'])] + string $locales, + #[With( + pattern: '^[a-zA-Z]+$', + minLength: 1, + maxLength: 10, + )] + string $text, + #[With( + minimum: 1, + maximum: 10, + multipleOf: 2, + exclusiveMinimum: 1, + exclusiveMaximum: 10, + )] + int $number, + #[With( + minItems: 1, + maxItems: 10, + uniqueItems: true, + minContains: 1, + maxContains: 10, + )] + array $products, + #[With( + required: true, + minProperties: 1, + maxProperties: 10, + dependentRequired: true, + )] + string $shippingAddress, + ): string { + return 'Hello, World!'; + } +} diff --git a/fixtures/Tool/ToolWithoutDocs.php b/fixtures/Tool/ToolWithoutDocs.php new file mode 100644 index 000000000..29a172558 --- /dev/null +++ b/fixtures/Tool/ToolWithoutDocs.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('tool_without_docs', 'A tool with required parameters', method: 'bar')] +final class ToolWithoutDocs +{ + public function bar(string $text, int $number): string + { + return \sprintf('%s says "%d".', $text, $number); + } +} diff --git a/fixtures/Tool/ToolWrong.php b/fixtures/Tool/ToolWrong.php new file mode 100644 index 000000000..7297b4f13 --- /dev/null +++ b/fixtures/Tool/ToolWrong.php @@ -0,0 +1,24 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Fixtures\Tool; + +final class ToolWrong +{ + /** + * @param string $text The text given to the tool + * @param int $number A number given to the tool + */ + public function bar(string $text, int $number): string + { + return \sprintf('%s says "%d".', $text, $number); + } +} diff --git a/fixtures/audio.mp3 b/fixtures/audio.mp3 new file mode 100644 index 000000000..509aa0fc4 Binary files /dev/null and b/fixtures/audio.mp3 differ diff --git a/fixtures/document.pdf b/fixtures/document.pdf new file mode 100644 index 000000000..00172f266 Binary files /dev/null and b/fixtures/document.pdf differ diff --git a/fixtures/image.jpg b/fixtures/image.jpg new file mode 100644 index 000000000..bae677adf Binary files /dev/null and b/fixtures/image.jpg differ diff --git a/src/agent/.gitattributes b/src/agent/.gitattributes new file mode 100644 index 000000000..ec8c01802 --- /dev/null +++ b/src/agent/.gitattributes @@ -0,0 +1,6 @@ +/.github export-ignore +/tests export-ignore +.gitattributes export-ignore +.gitignore export-ignore +phpstan.dist.neon export-ignore +phpunit.xml.dist export-ignore diff --git a/src/agent/.github/PULL_REQUEST_TEMPLATE.md b/src/agent/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..fcb87228a --- /dev/null +++ b/src/agent/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ +Please do not submit any Pull Requests here. They will be closed. +--- + +Please submit your PR here instead: +https://github.com/symfony/ai + +This repository is what we call a "subtree split": a read-only subset of that main repository. +We're looking forward to your PR there! diff --git a/src/agent/.github/workflows/close-pull-request.yml b/src/agent/.github/workflows/close-pull-request.yml new file mode 100644 index 000000000..207153fd5 --- /dev/null +++ b/src/agent/.github/workflows/close-pull-request.yml @@ -0,0 +1,20 @@ +name: Close Pull Request + +on: + pull_request_target: + types: [opened] + +jobs: + run: + runs-on: ubuntu-latest + steps: + - uses: superbrothers/close-pull-request@v3 + with: + comment: | + Thanks for your Pull Request! We love contributions. + + However, you should instead open your PR on the main repository: + https://github.com/symfony/ai + + This repository is what we call a "subtree split": a read-only subset of that main repository. + We're looking forward to your PR there! diff --git a/src/agent/.gitignore b/src/agent/.gitignore new file mode 100644 index 000000000..f43db636b --- /dev/null +++ b/src/agent/.gitignore @@ -0,0 +1,3 @@ +composer.lock +vendor +.phpunit.cache diff --git a/src/agent/LICENSE b/src/agent/LICENSE new file mode 100644 index 000000000..bc38d714e --- /dev/null +++ b/src/agent/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2025-present Fabien Potencier + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/src/agent/composer.json b/src/agent/composer.json new file mode 100644 index 000000000..daa458508 --- /dev/null +++ b/src/agent/composer.json @@ -0,0 +1,69 @@ +{ + "name": "symfony/ai-agent", + "type": "library", + "description": "PHP library for building agentic applications.", + "keywords": [ + "ai", + "llm", + "agent" + ], + "license": "MIT", + "authors": [ + { + "name": "Christopher Hertel", + "email": "mail@christopher-hertel.de" + }, + { + "name": "Oskar Stark", + "email": "oskarstark@googlemail.com" + } + ], + "require": { + "php": ">=8.2", + "ext-fileinfo": "*", + "oskarstark/enum-helper": "^1.5", + "phpdocumentor/reflection-docblock": "^5.4", + "phpstan/phpdoc-parser": "^2.1", + "psr/cache": "^3.0", + "psr/log": "^3.0", + "symfony/clock": "^6.4 || ^7.1", + "symfony/http-client": "^6.4 || ^7.1", + "symfony/property-access": "^6.4 || ^7.1", + "symfony/property-info": "^6.4 || ^7.1", + "symfony/serializer": "^6.4 || ^7.1", + "symfony/type-info": "^7.2.3", + "symfony/uid": "^6.4 || ^7.1", + "webmozart/assert": "^1.11" + }, + "require-dev": { + "phpstan/phpstan": "^2.0", + "phpstan/phpstan-symfony": "^2.0", + "phpstan/phpstan-webmozart-assert": "^2.0", + "phpunit/phpunit": "^11.5", + "symfony/console": "^6.4 || ^7.1", + "symfony/css-selector": "^6.4 || ^7.1", + "symfony/dom-crawler": "^6.4 || ^7.1", + "symfony/dotenv": "^6.4 || ^7.1", + "symfony/event-dispatcher": "^6.4 || ^7.1", + "symfony/finder": "^6.4 || ^7.1", + "symfony/process": "^6.4 || ^7.1", + "symfony/var-dumper": "^6.4 || ^7.1" + }, + "suggest": { + "symfony/css-selector": "For using the YouTube transcription tool.", + "symfony/dom-crawler": "For using the YouTube transcription tool." + }, + "config": { + "sort-packages": true + }, + "autoload": { + "psr-4": { + "Symfony\\AI\\Agent\\": "src/" + } + }, + "autoload-dev": { + "psr-4": { + "Symfony\\AI\\Agent\\Tests\\": "tests/" + } + } +} diff --git a/src/agent/phpstan.dist.neon b/src/agent/phpstan.dist.neon new file mode 100644 index 000000000..8cc83f644 --- /dev/null +++ b/src/agent/phpstan.dist.neon @@ -0,0 +1,10 @@ +includes: + - vendor/phpstan/phpstan-webmozart-assert/extension.neon + - vendor/phpstan/phpstan-symfony/extension.neon + +parameters: + level: 6 + paths: + - src/ + - tests/ + diff --git a/src/agent/src/Agent.php b/src/agent/src/Agent.php new file mode 100644 index 000000000..1e9eb1566 --- /dev/null +++ b/src/agent/src/Agent.php @@ -0,0 +1,121 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +use Psr\Log\LoggerInterface; +use Psr\Log\NullLogger; +use Symfony\AI\Agent\Exception\InvalidArgumentException; +use Symfony\AI\Agent\Exception\MissingModelSupportException; +use Symfony\AI\Agent\Exception\RuntimeException; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Response\AsyncResponse; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; +use Symfony\Contracts\HttpClient\Exception\HttpExceptionInterface; + +/** + * @author Christopher Hertel + */ +final readonly class Agent implements AgentInterface +{ + /** + * @var InputProcessorInterface[] + */ + private array $inputProcessors; + + /** + * @var OutputProcessorInterface[] + */ + private array $outputProcessors; + + /** + * @param InputProcessorInterface[] $inputProcessors + * @param OutputProcessorInterface[] $outputProcessors + */ + public function __construct( + private PlatformInterface $platform, + private Model $model, + iterable $inputProcessors = [], + iterable $outputProcessors = [], + private LoggerInterface $logger = new NullLogger(), + ) { + $this->inputProcessors = $this->initializeProcessors($inputProcessors, InputProcessorInterface::class); + $this->outputProcessors = $this->initializeProcessors($outputProcessors, OutputProcessorInterface::class); + } + + /** + * @param array $options + */ + public function call(MessageBagInterface $messages, array $options = []): ResponseInterface + { + $input = new Input($this->model, $messages, $options); + array_map(fn (InputProcessorInterface $processor) => $processor->processInput($input), $this->inputProcessors); + + $model = $input->model; + $messages = $input->messages; + $options = $input->getOptions(); + + if ($messages->containsAudio() && !$model->supports(Capability::INPUT_AUDIO)) { + throw MissingModelSupportException::forAudioInput($model::class); + } + + if ($messages->containsImage() && !$model->supports(Capability::INPUT_IMAGE)) { + throw MissingModelSupportException::forImageInput($model::class); + } + + try { + $response = $this->platform->request($model, $messages, $options); + + if ($response instanceof AsyncResponse) { + $response = $response->unwrap(); + } + } catch (ClientExceptionInterface $e) { + $message = $e->getMessage(); + $content = $e->getResponse()->toArray(false); + + $this->logger->debug($message, $content); + + throw new InvalidArgumentException('' === $message ? 'Invalid request to model or platform' : $message, previous: $e); + } catch (HttpExceptionInterface $e) { + throw new RuntimeException('Failed to request model', previous: $e); + } + + $output = new Output($model, $response, $messages, $options); + array_map(fn (OutputProcessorInterface $processor) => $processor->processOutput($output), $this->outputProcessors); + + return $output->response; + } + + /** + * @param InputProcessorInterface[]|OutputProcessorInterface[] $processors + * @param class-string $interface + * + * @return InputProcessorInterface[]|OutputProcessorInterface[] + */ + private function initializeProcessors(iterable $processors, string $interface): array + { + foreach ($processors as $processor) { + if (!$processor instanceof $interface) { + throw new InvalidArgumentException(\sprintf('Processor %s must implement %s interface.', $processor::class, $interface)); + } + + if ($processor instanceof AgentAwareInterface) { + $processor->setAgent($this); + } + } + + return $processors instanceof \Traversable ? iterator_to_array($processors) : $processors; + } +} diff --git a/src/agent/src/AgentAwareInterface.php b/src/agent/src/AgentAwareInterface.php new file mode 100644 index 000000000..843da2c74 --- /dev/null +++ b/src/agent/src/AgentAwareInterface.php @@ -0,0 +1,20 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +/** + * @author Christopher Hertel + */ +interface AgentAwareInterface +{ + public function setAgent(AgentInterface $agent): void; +} diff --git a/src/agent/src/AgentAwareTrait.php b/src/agent/src/AgentAwareTrait.php new file mode 100644 index 000000000..56b826843 --- /dev/null +++ b/src/agent/src/AgentAwareTrait.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +/** + * @author Christopher Hertel + */ +trait AgentAwareTrait +{ + private AgentInterface $agent; + + public function setAgent(AgentInterface $agent): void + { + $this->agent = $agent; + } +} diff --git a/src/agent/src/AgentInterface.php b/src/agent/src/AgentInterface.php new file mode 100644 index 000000000..d205836dd --- /dev/null +++ b/src/agent/src/AgentInterface.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Response\ResponseInterface; + +/** + * @author Denis Zunke + */ +interface AgentInterface +{ + /** + * @param array $options + */ + public function call(MessageBagInterface $messages, array $options = []): ResponseInterface; +} diff --git a/src/agent/src/Exception/ExceptionInterface.php b/src/agent/src/Exception/ExceptionInterface.php new file mode 100644 index 000000000..606960fc2 --- /dev/null +++ b/src/agent/src/Exception/ExceptionInterface.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Exception; + +/** + * @author Oskar Stark + */ +interface ExceptionInterface extends \Throwable +{ +} diff --git a/src/agent/src/Exception/InvalidArgumentException.php b/src/agent/src/Exception/InvalidArgumentException.php new file mode 100644 index 000000000..71e15909f --- /dev/null +++ b/src/agent/src/Exception/InvalidArgumentException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Exception; + +/** + * @author Oskar Stark + */ +class InvalidArgumentException extends \InvalidArgumentException implements ExceptionInterface +{ +} diff --git a/src/agent/src/Exception/LogicException.php b/src/agent/src/Exception/LogicException.php new file mode 100644 index 000000000..3eff060d9 --- /dev/null +++ b/src/agent/src/Exception/LogicException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Exception; + +/** + * @author Oskar Stark + */ +class LogicException extends \LogicException implements ExceptionInterface +{ +} diff --git a/src/agent/src/Exception/MissingModelSupportException.php b/src/agent/src/Exception/MissingModelSupportException.php new file mode 100644 index 000000000..eb43df1ec --- /dev/null +++ b/src/agent/src/Exception/MissingModelSupportException.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Exception; + +/** + * @author Christopher Hertel + */ +final class MissingModelSupportException extends RuntimeException +{ + private function __construct(string $model, string $support) + { + parent::__construct(\sprintf('Model "%s" does not support "%s".', $model, $support)); + } + + public static function forToolCalling(string $model): self + { + return new self($model, 'tool calling'); + } + + public static function forAudioInput(string $model): self + { + return new self($model, 'audio input'); + } + + public static function forImageInput(string $model): self + { + return new self($model, 'image input'); + } + + public static function forStructuredOutput(string $model): self + { + return new self($model, 'structured output'); + } +} diff --git a/src/agent/src/Exception/RuntimeException.php b/src/agent/src/Exception/RuntimeException.php new file mode 100644 index 000000000..c0a2ee8db --- /dev/null +++ b/src/agent/src/Exception/RuntimeException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Exception; + +/** + * @author Oskar Stark + */ +class RuntimeException extends \RuntimeException implements ExceptionInterface +{ +} diff --git a/src/agent/src/Input.php b/src/agent/src/Input.php new file mode 100644 index 000000000..7b0bd905e --- /dev/null +++ b/src/agent/src/Input.php @@ -0,0 +1,47 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class Input +{ + /** + * @param array $options + */ + public function __construct( + public Model $model, + public MessageBagInterface $messages, + private array $options, + ) { + } + + /** + * @return array + */ + public function getOptions(): array + { + return $this->options; + } + + /** + * @param array $options + */ + public function setOptions(array $options): void + { + $this->options = $options; + } +} diff --git a/src/agent/src/InputProcessor/ModelOverrideInputProcessor.php b/src/agent/src/InputProcessor/ModelOverrideInputProcessor.php new file mode 100644 index 000000000..7dafea175 --- /dev/null +++ b/src/agent/src/InputProcessor/ModelOverrideInputProcessor.php @@ -0,0 +1,38 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\InputProcessor; + +use Symfony\AI\Agent\Exception\InvalidArgumentException; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\InputProcessorInterface; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class ModelOverrideInputProcessor implements InputProcessorInterface +{ + public function processInput(Input $input): void + { + $options = $input->getOptions(); + + if (!\array_key_exists('model', $options)) { + return; + } + + if (!$options['model'] instanceof Model) { + throw new InvalidArgumentException(\sprintf('Option "model" must be an instance of %s.', Model::class)); + } + + $input->model = $options['model']; + } +} diff --git a/src/agent/src/InputProcessor/SystemPromptInputProcessor.php b/src/agent/src/InputProcessor/SystemPromptInputProcessor.php new file mode 100644 index 000000000..9e9c9c370 --- /dev/null +++ b/src/agent/src/InputProcessor/SystemPromptInputProcessor.php @@ -0,0 +1,74 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\InputProcessor; + +use Psr\Log\LoggerInterface; +use Psr\Log\NullLogger; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\InputProcessorInterface; +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @author Christopher Hertel + */ +final readonly class SystemPromptInputProcessor implements InputProcessorInterface +{ + /** + * @param \Stringable|string $systemPrompt the system prompt to prepend to the input messages + * @param ToolboxInterface|null $toolbox the tool box to be used to append the tool definitions to the system prompt + */ + public function __construct( + private \Stringable|string $systemPrompt, + private ?ToolboxInterface $toolbox = null, + private LoggerInterface $logger = new NullLogger(), + ) { + } + + public function processInput(Input $input): void + { + $messages = $input->messages; + + if (null !== $messages->getSystemMessage()) { + $this->logger->debug('Skipping system prompt injection since MessageBag already contains a system message.'); + + return; + } + + $message = (string) $this->systemPrompt; + + if ($this->toolbox instanceof ToolboxInterface + && [] !== $this->toolbox->getTools() + ) { + $this->logger->debug('Append tool definitions to system prompt.'); + + $tools = implode(\PHP_EOL.\PHP_EOL, array_map( + fn (Tool $tool) => <<name} + {$tool->description} + TOOL, + $this->toolbox->getTools() + )); + + $message = <<systemPrompt} + + # Available tools + + {$tools} + PROMPT; + } + + $input->messages = $messages->prepend(Message::forSystem($message)); + } +} diff --git a/src/agent/src/InputProcessorInterface.php b/src/agent/src/InputProcessorInterface.php new file mode 100644 index 000000000..fc0868bb4 --- /dev/null +++ b/src/agent/src/InputProcessorInterface.php @@ -0,0 +1,20 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +/** + * @author Christopher Hertel + */ +interface InputProcessorInterface +{ + public function processInput(Input $input): void; +} diff --git a/src/agent/src/Output.php b/src/agent/src/Output.php new file mode 100644 index 000000000..dffd7381c --- /dev/null +++ b/src/agent/src/Output.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final class Output +{ + /** + * @param array $options + */ + public function __construct( + public readonly Model $model, + public ResponseInterface $response, + public readonly MessageBagInterface $messages, + public readonly array $options, + ) { + } +} diff --git a/src/agent/src/OutputProcessorInterface.php b/src/agent/src/OutputProcessorInterface.php new file mode 100644 index 000000000..a6ad499c9 --- /dev/null +++ b/src/agent/src/OutputProcessorInterface.php @@ -0,0 +1,20 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent; + +/** + * @author Christopher Hertel + */ +interface OutputProcessorInterface +{ + public function processOutput(Output $output): void; +} diff --git a/src/agent/src/StructuredOutput/AgentProcessor.php b/src/agent/src/StructuredOutput/AgentProcessor.php new file mode 100644 index 000000000..50a95a292 --- /dev/null +++ b/src/agent/src/StructuredOutput/AgentProcessor.php @@ -0,0 +1,94 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\StructuredOutput; + +use Symfony\AI\Agent\Exception\InvalidArgumentException; +use Symfony\AI\Agent\Exception\MissingModelSupportException; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\InputProcessorInterface; +use Symfony\AI\Agent\Output; +use Symfony\AI\Agent\OutputProcessorInterface; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Response\ObjectResponse; +use Symfony\Component\PropertyInfo\Extractor\PhpDocExtractor; +use Symfony\Component\PropertyInfo\PropertyInfoExtractor; +use Symfony\Component\Serializer\Encoder\JsonEncoder; +use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer; +use Symfony\Component\Serializer\Normalizer\ObjectNormalizer; +use Symfony\Component\Serializer\Serializer; +use Symfony\Component\Serializer\SerializerInterface; + +/** + * @author Christopher Hertel + */ +final class AgentProcessor implements InputProcessorInterface, OutputProcessorInterface +{ + private string $outputStructure; + + public function __construct( + private readonly ResponseFormatFactoryInterface $responseFormatFactory = new ResponseFormatFactory(), + private ?SerializerInterface $serializer = null, + ) { + if (null === $this->serializer) { + $propertyInfo = new PropertyInfoExtractor([], [new PhpDocExtractor()]); + $normalizers = [new ObjectNormalizer(propertyTypeExtractor: $propertyInfo), new ArrayDenormalizer()]; + $this->serializer = new Serializer($normalizers, [new JsonEncoder()]); + } + } + + public function processInput(Input $input): void + { + $options = $input->getOptions(); + + if (!isset($options['output_structure'])) { + return; + } + + if (!$input->model->supports(Capability::OUTPUT_STRUCTURED)) { + throw MissingModelSupportException::forStructuredOutput($input->model::class); + } + + if (true === ($options['stream'] ?? false)) { + throw new InvalidArgumentException('Streamed responses are not supported for structured output'); + } + + $options['response_format'] = $this->responseFormatFactory->create($options['output_structure']); + + $this->outputStructure = $options['output_structure']; + unset($options['output_structure']); + + $input->setOptions($options); + } + + public function processOutput(Output $output): void + { + $options = $output->options; + + if ($output->response instanceof ObjectResponse) { + return; + } + + if (!isset($options['response_format'])) { + return; + } + + if (!isset($this->outputStructure)) { + $output->response = new ObjectResponse(json_decode($output->response->getContent(), true)); + + return; + } + + $output->response = new ObjectResponse( + $this->serializer->deserialize($output->response->getContent(), $this->outputStructure, 'json') + ); + } +} diff --git a/src/agent/src/StructuredOutput/ResponseFormatFactory.php b/src/agent/src/StructuredOutput/ResponseFormatFactory.php new file mode 100644 index 000000000..a81811925 --- /dev/null +++ b/src/agent/src/StructuredOutput/ResponseFormatFactory.php @@ -0,0 +1,39 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\StructuredOutput; + +use Symfony\AI\Platform\Contract\JsonSchema\Factory; + +use function Symfony\Component\String\u; + +/** + * @author Christopher Hertel + */ +final readonly class ResponseFormatFactory implements ResponseFormatFactoryInterface +{ + public function __construct( + private Factory $schemaFactory = new Factory(), + ) { + } + + public function create(string $responseClass): array + { + return [ + 'type' => 'json_schema', + 'json_schema' => [ + 'name' => u($responseClass)->afterLast('\\')->toString(), + 'schema' => $this->schemaFactory->buildProperties($responseClass), + 'strict' => true, + ], + ]; + } +} diff --git a/src/agent/src/StructuredOutput/ResponseFormatFactoryInterface.php b/src/agent/src/StructuredOutput/ResponseFormatFactoryInterface.php new file mode 100644 index 000000000..ab28b1091 --- /dev/null +++ b/src/agent/src/StructuredOutput/ResponseFormatFactoryInterface.php @@ -0,0 +1,32 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\StructuredOutput; + +/** + * @author Oskar Stark + */ +interface ResponseFormatFactoryInterface +{ + /** + * @param class-string $responseClass + * + * @return array{ + * type: 'json_schema', + * json_schema: array{ + * name: string, + * schema: array, + * strict: true, + * } + * } + */ + public function create(string $responseClass): array; +} diff --git a/src/agent/src/Toolbox/AgentProcessor.php b/src/agent/src/Toolbox/AgentProcessor.php new file mode 100644 index 000000000..b3c626e78 --- /dev/null +++ b/src/agent/src/Toolbox/AgentProcessor.php @@ -0,0 +1,122 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Symfony\AI\Agent\AgentAwareInterface; +use Symfony\AI\Agent\AgentAwareTrait; +use Symfony\AI\Agent\Exception\MissingModelSupportException; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\InputProcessorInterface; +use Symfony\AI\Agent\Output; +use Symfony\AI\Agent\OutputProcessorInterface; +use Symfony\AI\Agent\Toolbox\Event\ToolCallsExecuted; +use Symfony\AI\Agent\Toolbox\StreamResponse as ToolboxStreamResponse; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\AI\Platform\Response\StreamResponse as GenericStreamResponse; +use Symfony\AI\Platform\Response\ToolCallResponse; +use Symfony\AI\Platform\Tool\Tool; +use Symfony\Contracts\EventDispatcher\EventDispatcherInterface; + +/** + * @author Christopher Hertel + */ +final class AgentProcessor implements InputProcessorInterface, OutputProcessorInterface, AgentAwareInterface +{ + use AgentAwareTrait; + + public function __construct( + private readonly ToolboxInterface $toolbox, + private readonly ToolResultConverter $resultConverter = new ToolResultConverter(), + private readonly ?EventDispatcherInterface $eventDispatcher = null, + ) { + } + + public function processInput(Input $input): void + { + if (!$input->model->supports(Capability::TOOL_CALLING)) { + throw MissingModelSupportException::forToolCalling($input->model::class); + } + + $toolMap = $this->toolbox->getTools(); + if ([] === $toolMap) { + return; + } + + $options = $input->getOptions(); + // only filter tool map if list of strings is provided as option + if (isset($options['tools']) && $this->isFlatStringArray($options['tools'])) { + $toolMap = array_values(array_filter($toolMap, fn (Tool $tool) => \in_array($tool->name, $options['tools'], true))); + } + + $options['tools'] = $toolMap; + $input->setOptions($options); + } + + public function processOutput(Output $output): void + { + if ($output->response instanceof GenericStreamResponse) { + $output->response = new ToolboxStreamResponse( + $output->response->getContent(), + $this->handleToolCallsCallback($output), + ); + + return; + } + + if (!$output->response instanceof ToolCallResponse) { + return; + } + + $output->response = $this->handleToolCallsCallback($output)($output->response); + } + + /** + * @param array $tools + */ + private function isFlatStringArray(array $tools): bool + { + return array_reduce($tools, fn (bool $carry, mixed $item) => $carry && \is_string($item), true); + } + + private function handleToolCallsCallback(Output $output): \Closure + { + return function (ToolCallResponse $response, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResponseInterface { + $messages = clone $output->messages; + + if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->content) { + $messages->add($streamedAssistantResponse); + } + + do { + $toolCalls = $response->getContent(); + $messages->add(Message::ofAssistant(toolCalls: $toolCalls)); + + $results = []; + foreach ($toolCalls as $toolCall) { + $result = $this->toolbox->execute($toolCall); + $results[] = new ToolCallResult($toolCall, $result); + $messages->add(Message::ofToolCall($toolCall, $this->resultConverter->convert($result))); + } + + $event = new ToolCallsExecuted(...$results); + $this->eventDispatcher?->dispatch($event); + + $response = $event->hasResponse() ? $event->response : $this->agent->call($messages, $output->options); + } while ($response instanceof ToolCallResponse); + + return $response; + }; + } +} diff --git a/src/agent/src/Toolbox/Attribute/AsTool.php b/src/agent/src/Toolbox/Attribute/AsTool.php new file mode 100644 index 000000000..04811ac33 --- /dev/null +++ b/src/agent/src/Toolbox/Attribute/AsTool.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Attribute; + +/** + * @author Christopher Hertel + */ +#[\Attribute(\Attribute::TARGET_CLASS | \Attribute::IS_REPEATABLE)] +final readonly class AsTool +{ + public function __construct( + public string $name, + public string $description, + public string $method = '__invoke', + ) { + } +} diff --git a/src/agent/src/Toolbox/Event/ToolCallsExecuted.php b/src/agent/src/Toolbox/Event/ToolCallsExecuted.php new file mode 100644 index 000000000..4101b8d0a --- /dev/null +++ b/src/agent/src/Toolbox/Event/ToolCallsExecuted.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Event; + +use Symfony\AI\Agent\Toolbox\ToolCallResult; +use Symfony\AI\Platform\Response\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final class ToolCallsExecuted +{ + /** + * @var ToolCallResult[] + */ + public readonly array $toolCallResults; + public ResponseInterface $response; + + public function __construct(ToolCallResult ...$toolCallResults) + { + $this->toolCallResults = $toolCallResults; + } + + public function hasResponse(): bool + { + return isset($this->response); + } +} diff --git a/src/agent/src/Toolbox/Exception/ExceptionInterface.php b/src/agent/src/Toolbox/Exception/ExceptionInterface.php new file mode 100644 index 000000000..bbb590084 --- /dev/null +++ b/src/agent/src/Toolbox/Exception/ExceptionInterface.php @@ -0,0 +1,21 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Exception; + +use Symfony\AI\Agent\Exception\ExceptionInterface as BaseExceptionInterface; + +/** + * @author Christopher Hertel + */ +interface ExceptionInterface extends BaseExceptionInterface +{ +} diff --git a/src/agent/src/Toolbox/Exception/ToolConfigurationException.php b/src/agent/src/Toolbox/Exception/ToolConfigurationException.php new file mode 100644 index 000000000..0e39c7f30 --- /dev/null +++ b/src/agent/src/Toolbox/Exception/ToolConfigurationException.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Exception; + +use Symfony\AI\Agent\Exception\InvalidArgumentException; + +/** + * @author Christopher Hertel + */ +final class ToolConfigurationException extends InvalidArgumentException implements ExceptionInterface +{ + public static function invalidMethod(string $toolClass, string $methodName, \ReflectionException $previous): self + { + return new self(\sprintf('Method "%s" not found in tool "%s".', $methodName, $toolClass), previous: $previous); + } +} diff --git a/src/agent/src/Toolbox/Exception/ToolException.php b/src/agent/src/Toolbox/Exception/ToolException.php new file mode 100644 index 000000000..08e496f6c --- /dev/null +++ b/src/agent/src/Toolbox/Exception/ToolException.php @@ -0,0 +1,31 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Exception; + +use Symfony\AI\Agent\Exception\InvalidArgumentException; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +/** + * @author Christopher Hertel + */ +final class ToolException extends InvalidArgumentException implements ExceptionInterface +{ + public static function invalidReference(mixed $reference): self + { + return new self(\sprintf('The reference "%s" is not a valid tool.', $reference)); + } + + public static function missingAttribute(string $className): self + { + return new self(\sprintf('The class "%s" is not a tool, please add %s attribute.', $className, AsTool::class)); + } +} diff --git a/src/agent/src/Toolbox/Exception/ToolExecutionException.php b/src/agent/src/Toolbox/Exception/ToolExecutionException.php new file mode 100644 index 000000000..3577020b5 --- /dev/null +++ b/src/agent/src/Toolbox/Exception/ToolExecutionException.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Exception; + +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @author Christopher Hertel + */ +final class ToolExecutionException extends \RuntimeException implements ExceptionInterface +{ + public ?ToolCall $toolCall = null; + + public static function executionFailed(ToolCall $toolCall, \Throwable $previous): self + { + $exception = new self(\sprintf('Execution of tool "%s" failed with error: %s', $toolCall->name, $previous->getMessage()), previous: $previous); + $exception->toolCall = $toolCall; + + return $exception; + } +} diff --git a/src/agent/src/Toolbox/Exception/ToolNotFoundException.php b/src/agent/src/Toolbox/Exception/ToolNotFoundException.php new file mode 100644 index 000000000..8e1fb9009 --- /dev/null +++ b/src/agent/src/Toolbox/Exception/ToolNotFoundException.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Exception; + +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\ExecutionReference; + +/** + * @author Christopher Hertel + */ +final class ToolNotFoundException extends \RuntimeException implements ExceptionInterface +{ + public ?ToolCall $toolCall = null; + + public static function notFoundForToolCall(ToolCall $toolCall): self + { + $exception = new self(\sprintf('Tool not found for call: %s.', $toolCall->name)); + $exception->toolCall = $toolCall; + + return $exception; + } + + public static function notFoundForReference(ExecutionReference $reference): self + { + return new self(\sprintf('Tool not found for reference: %s::%s.', $reference->class, $reference->method)); + } +} diff --git a/src/agent/src/Toolbox/FaultTolerantToolbox.php b/src/agent/src/Toolbox/FaultTolerantToolbox.php new file mode 100644 index 000000000..f372ab1ec --- /dev/null +++ b/src/agent/src/Toolbox/FaultTolerantToolbox.php @@ -0,0 +1,48 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; +use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\Tool; + +/** + * Catches exceptions thrown by the inner tool box and returns error messages for the LLM instead. + * + * @author Christopher Hertel + */ +final readonly class FaultTolerantToolbox implements ToolboxInterface +{ + public function __construct( + private ToolboxInterface $innerToolbox, + ) { + } + + public function getTools(): array + { + return $this->innerToolbox->getTools(); + } + + public function execute(ToolCall $toolCall): mixed + { + try { + return $this->innerToolbox->execute($toolCall); + } catch (ToolExecutionException $e) { + return \sprintf('An error occurred while executing tool "%s".', $e->toolCall->name); + } catch (ToolNotFoundException) { + $names = array_map(fn (Tool $metadata) => $metadata->name, $this->getTools()); + + return \sprintf('Tool "%s" was not found, please use one of these: %s', $toolCall->name, implode(', ', $names)); + } + } +} diff --git a/src/agent/src/Toolbox/StreamResponse.php b/src/agent/src/Toolbox/StreamResponse.php new file mode 100644 index 000000000..475dbb600 --- /dev/null +++ b/src/agent/src/Toolbox/StreamResponse.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Response\BaseResponse; +use Symfony\AI\Platform\Response\ToolCallResponse; + +/** + * @author Denis Zunke + */ +final class StreamResponse extends BaseResponse +{ + public function __construct( + private readonly \Generator $generator, + private readonly \Closure $handleToolCallsCallback, + ) { + } + + public function getContent(): \Generator + { + $streamedResponse = ''; + foreach ($this->generator as $value) { + if ($value instanceof ToolCallResponse) { + yield from ($this->handleToolCallsCallback)($value, Message::ofAssistant($streamedResponse))->getContent(); + + break; + } + + $streamedResponse .= $value; + yield $value; + } + } +} diff --git a/src/agent/src/Toolbox/Tool/Agent.php b/src/agent/src/Toolbox/Tool/Agent.php new file mode 100644 index 000000000..cea81ca13 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Agent.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\AgentInterface; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Response\TextResponse; + +/** + * @author Christopher Hertel + */ +final readonly class Agent +{ + public function __construct( + private AgentInterface $agent, + ) { + } + + /** + * @param string $message the message to pass to the agent + */ + public function __invoke(string $message): string + { + $response = $this->agent->call(new MessageBag(Message::ofUser($message))); + + \assert($response instanceof TextResponse); + + return $response->getContent(); + } +} diff --git a/src/agent/src/Toolbox/Tool/Brave.php b/src/agent/src/Toolbox/Tool/Brave.php new file mode 100644 index 000000000..4a25d3006 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Brave.php @@ -0,0 +1,70 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool('brave_search', 'Tool that searches the web using Brave Search')] +final readonly class Brave +{ + /** + * @param array $options See https://api-dashboard.search.brave.com/app/documentation/web-search/query#WebSearchAPIQueryParameters + */ + public function __construct( + private HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + private array $options = [], + ) { + } + + /** + * @param string $query the search query term + * @param int $count The number of search results returned in response. + * Combine this parameter with offset to paginate search results. + * @param int $offset The number of search results to skip before returning results. + * In order to paginate results use this parameter together with count. + * + * @return array + */ + public function __invoke( + #[With(maximum: 500)] + string $query, + int $count = 20, + #[With(minimum: 0, maximum: 9)] + int $offset = 0, + ): array { + $response = $this->httpClient->request('GET', 'https://api.search.brave.com/res/v1/web/search', [ + 'headers' => ['X-Subscription-Token' => $this->apiKey], + 'query' => array_merge($this->options, [ + 'q' => $query, + 'count' => $count, + 'offset' => $offset, + ]), + ]); + + $data = $response->toArray(); + + return array_map(static function (array $result) { + return ['title' => $result['title'], 'description' => $result['description'], 'url' => $result['url']]; + }, $data['web']['results'] ?? []); + } +} diff --git a/src/agent/src/Toolbox/Tool/Clock.php b/src/agent/src/Toolbox/Tool/Clock.php new file mode 100644 index 000000000..562b3ccf8 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Clock.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Component\Clock\Clock as SymfonyClock; +use Symfony\Component\Clock\ClockInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool('clock', description: 'Provides the current date and time.')] +final readonly class Clock +{ + public function __construct( + private ClockInterface $clock = new SymfonyClock(), + ) { + } + + public function __invoke(): string + { + return \sprintf( + 'Current date is %s (YYYY-MM-DD) and the time is %s (HH:MM:SS).', + $this->clock->now()->format('Y-m-d'), + $this->clock->now()->format('H:i:s'), + ); + } +} diff --git a/src/agent/src/Toolbox/Tool/Crawler.php b/src/agent/src/Toolbox/Tool/Crawler.php new file mode 100644 index 000000000..96e809eee --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Crawler.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Exception\RuntimeException; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Component\DomCrawler\Crawler as DomCrawler; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool('crawler', 'A tool that crawls one page of a website and returns the visible text of it.')] +final readonly class Crawler +{ + public function __construct( + private HttpClientInterface $httpClient, + ) { + if (!class_exists(DomCrawler::class)) { + throw new RuntimeException('The DomCrawler component is not installed. Please install it using "composer require symfony/dom-crawler".'); + } + } + + /** + * @param string $url the URL of the page to crawl + */ + public function __invoke(string $url): string + { + $response = $this->httpClient->request('GET', $url); + + return (new DomCrawler($response->getContent()))->filter('body')->text(); + } +} diff --git a/src/agent/src/Toolbox/Tool/OpenMeteo.php b/src/agent/src/Toolbox/Tool/OpenMeteo.php new file mode 100644 index 000000000..7e0adef82 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/OpenMeteo.php @@ -0,0 +1,132 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool(name: 'weather_current', description: 'get current weather for a location', method: 'current')] +#[AsTool(name: 'weather_forecast', description: 'get weather forecast for a location', method: 'forecast')] +final readonly class OpenMeteo +{ + private const WMO_CODES = [ + 0 => 'Clear', + 1 => 'Mostly Clear', + 2 => 'Partly Cloudy', + 3 => 'Overcast', + 45 => 'Fog', + 48 => 'Icy Fog', + 51 => 'Light Drizzle', + 53 => 'Drizzle', + 55 => 'Heavy Drizzle', + 56 => 'Light Freezing Drizzle', + 57 => 'Freezing Drizzle', + 61 => 'Light Rain', + 63 => 'Rain', + 65 => 'Heavy Rain', + 66 => 'Light Freezing Rain', + 67 => 'Freezing Rain', + 71 => 'Light Snow', + 73 => 'Snow', + 75 => 'Heavy Snow', + 77 => 'Snow Grains', + 80 => 'Light Showers', + 81 => 'Showers', + 82 => 'Heavy Showers', + 85 => 'Light Snow Showers', + 86 => 'Snow Showers', + 95 => 'Thunderstorm', + 96 => 'Light Thunderstorm with Hail', + 99 => 'Thunderstorm with Hail', + ]; + + public function __construct( + private HttpClientInterface $httpClient, + ) { + } + + /** + * @param float $latitude the latitude of the location + * @param float $longitude the longitude of the location + * + * @return array{ + * weather: string, + * time: string, + * temperature: string, + * wind_speed: string, + * } + */ + public function current(float $latitude, float $longitude): array + { + $response = $this->httpClient->request('GET', 'https://api.open-meteo.com/v1/forecast', [ + 'query' => [ + 'latitude' => $latitude, + 'longitude' => $longitude, + 'current' => 'weather_code,temperature_2m,wind_speed_10m', + ], + ]); + + $data = $response->toArray(); + + return [ + 'weather' => self::WMO_CODES[$data['current']['weather_code']] ?? 'Unknown', + 'time' => $data['current']['time'], + 'temperature' => $data['current']['temperature_2m'].$data['current_units']['temperature_2m'], + 'wind_speed' => $data['current']['wind_speed_10m'].$data['current_units']['wind_speed_10m'], + ]; + } + + /** + * @param float $latitude the latitude of the location + * @param float $longitude the longitude of the location + * @param int $days the number of days to forecast + * + * @return array{ + * weather: string, + * time: string, + * temperature_min: string, + * temperature_max: string, + * }[] + */ + public function forecast( + float $latitude, + float $longitude, + #[With(minimum: 1, maximum: 16)] + int $days = 7, + ): array { + $response = $this->httpClient->request('GET', 'https://api.open-meteo.com/v1/forecast', [ + 'query' => [ + 'latitude' => $latitude, + 'longitude' => $longitude, + 'daily' => 'weather_code,temperature_2m_max,temperature_2m_min', + 'forecast_days' => $days, + ], + ]); + + $data = $response->toArray(); + $forecast = []; + for ($i = 0; $i < $days; ++$i) { + $forecast[] = [ + 'weather' => self::WMO_CODES[$data['daily']['weather_code'][$i]] ?? 'Unknown', + 'time' => $data['daily']['time'][$i], + 'temperature_min' => $data['daily']['temperature_2m_min'][$i].$data['daily_units']['temperature_2m_min'], + 'temperature_max' => $data['daily']['temperature_2m_max'][$i].$data['daily_units']['temperature_2m_max'], + ]; + } + + return $forecast; + } +} diff --git a/src/agent/src/Toolbox/Tool/SerpApi.php b/src/agent/src/Toolbox/Tool/SerpApi.php new file mode 100644 index 000000000..782cd97c7 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/SerpApi.php @@ -0,0 +1,51 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool(name: 'serpapi', description: 'search for information on the internet')] +final readonly class SerpApi +{ + public function __construct( + private HttpClientInterface $httpClient, + private string $apiKey, + ) { + } + + /** + * @param string $query The search query to use + */ + public function __invoke(string $query): string + { + $response = $this->httpClient->request('GET', 'https://serpapi.com/search', [ + 'query' => [ + 'q' => $query, + 'api_key' => $this->apiKey, + ], + ]); + + return \sprintf('Results for "%s" are "%s".', $query, $this->extractBestResponse($response->toArray())); + } + + /** + * @param array $results + */ + private function extractBestResponse(array $results): string + { + return implode('. ', array_map(fn ($story) => $story['title'], $results['organic_results'])); + } +} diff --git a/src/agent/src/Toolbox/Tool/SimilaritySearch.php b/src/agent/src/Toolbox/Tool/SimilaritySearch.php new file mode 100644 index 000000000..05b0f2da9 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/SimilaritySearch.php @@ -0,0 +1,59 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\VectorStoreInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool('similarity_search', description: 'Searches for documents similar to a query or sentence.')] +final class SimilaritySearch +{ + /** + * @var VectorDocument[] + */ + public array $usedDocuments = []; + + public function __construct( + private readonly PlatformInterface $platform, + private readonly Model $model, + private readonly VectorStoreInterface $vectorStore, + ) { + } + + /** + * @param string $searchTerm string used for similarity search + */ + public function __invoke(string $searchTerm): string + { + /** @var Vector[] $vectors */ + $vectors = $this->platform->request($this->model, $searchTerm)->getContent(); + $this->usedDocuments = $this->vectorStore->query($vectors[0]); + + if (0 === \count($this->usedDocuments)) { + return 'No results found'; + } + + $result = 'Found documents with following information:'.\PHP_EOL; + foreach ($this->usedDocuments as $document) { + $result .= json_encode($document->metadata); + } + + return $result; + } +} diff --git a/src/agent/src/Toolbox/Tool/Tavily.php b/src/agent/src/Toolbox/Tool/Tavily.php new file mode 100644 index 000000000..f84848d7a --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Tavily.php @@ -0,0 +1,65 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * Tool integration of tavily.com. + * + * @author Christopher Hertel + */ +#[AsTool('tavily_search', description: 'search for information on the internet', method: 'search')] +#[AsTool('tavily_extract', description: 'fetch content from websites', method: 'extract')] +final readonly class Tavily +{ + /** + * @param array $options + */ + public function __construct( + private HttpClientInterface $httpClient, + private string $apiKey, + private array $options = ['include_images' => false], + ) { + } + + /** + * @param string $query The search query to use + */ + public function search(string $query): string + { + $response = $this->httpClient->request('POST', 'https://api.tavily.com/search', [ + 'json' => array_merge($this->options, [ + 'query' => $query, + 'api_key' => $this->apiKey, + ]), + ]); + + return $response->getContent(); + } + + /** + * @param string[] $urls URLs to fetch information from + */ + public function extract(array $urls): string + { + $response = $this->httpClient->request('POST', 'https://api.tavily.com/extract', [ + 'json' => [ + 'urls' => $urls, + 'api_key' => $this->apiKey, + ], + ]); + + return $response->getContent(); + } +} diff --git a/src/agent/src/Toolbox/Tool/Wikipedia.php b/src/agent/src/Toolbox/Tool/Wikipedia.php new file mode 100644 index 000000000..9aa527373 --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Wikipedia.php @@ -0,0 +1,99 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool('wikipedia_search', description: 'Searches Wikipedia for a given query', method: 'search')] +#[AsTool('wikipedia_article', description: 'Retrieves a Wikipedia article by its title', method: 'article')] +final readonly class Wikipedia +{ + public function __construct( + private HttpClientInterface $httpClient, + private string $locale = 'en', + ) { + } + + /** + * @param string $query The query to search for on Wikipedia + */ + public function search(string $query): string + { + $result = $this->execute([ + 'action' => 'query', + 'format' => 'json', + 'list' => 'search', + 'srsearch' => $query, + ], $this->locale); + + $titles = array_map(fn (array $item) => $item['title'], $result['query']['search']); + + if (empty($titles)) { + return 'No articles were found on Wikipedia.'; + } + + $response = 'Articles with the following titles were found on Wikipedia:'.\PHP_EOL; + foreach ($titles as $title) { + $response .= ' - '.$title.\PHP_EOL; + } + + return $response.\PHP_EOL.'Use the title of the article with tool "wikipedia_article" to load the content.'; + } + + /** + * @param string $title The title of the article to load from Wikipedia + */ + public function article(string $title): string + { + $result = $this->execute([ + 'action' => 'query', + 'format' => 'json', + 'prop' => 'extracts|info|pageimages', + 'titles' => $title, + 'explaintext' => true, + 'redirects' => true, + ], $this->locale); + + $article = current($result['query']['pages']); + + if (\array_key_exists('missing', $article)) { + return \sprintf('No article with title "%s" was found on Wikipedia.', $title); + } + + $response = ''; + if (\array_key_exists('redirects', $result['query'])) { + foreach ($result['query']['redirects'] as $redirect) { + $response .= \sprintf('The article "%s" redirects to article "%s".', $redirect['from'], $redirect['to']).\PHP_EOL; + } + $response .= \PHP_EOL; + } + + return $response.'This is the content of article "'.$article['title'].'":'.\PHP_EOL.$article['extract']; + } + + /** + * @param array $query + * + * @return array + */ + private function execute(array $query, ?string $locale = null): array + { + $url = \sprintf('https://%s.wikipedia.org/w/api.php', $locale ?? $this->locale); + $response = $this->httpClient->request('GET', $url, ['query' => $query]); + + return $response->toArray(); + } +} diff --git a/src/agent/src/Toolbox/Tool/YouTubeTranscriber.php b/src/agent/src/Toolbox/Tool/YouTubeTranscriber.php new file mode 100644 index 000000000..90c1881ed --- /dev/null +++ b/src/agent/src/Toolbox/Tool/YouTubeTranscriber.php @@ -0,0 +1,79 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Exception\LogicException; +use Symfony\AI\Agent\Exception\RuntimeException; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Component\CssSelector\CssSelectorConverter; +use Symfony\Component\DomCrawler\Crawler; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +#[AsTool('youtube_transcript', 'Fetches the transcript of a YouTube video')] +final readonly class YouTubeTranscriber +{ + public function __construct( + private HttpClientInterface $client, + ) { + if (!class_exists(Crawler::class)) { + throw new LogicException('The Symfony DomCrawler component is required to use this tool. Try running "composer require symfony/dom-crawler".'); + } + if (!class_exists(CssSelectorConverter::class)) { + throw new LogicException('The Symfony CSS Selector component is required to use this tool. Try running "composer require symfony/css-selector".'); + } + } + + /** + * @param string $videoId The ID of the YouTube video + */ + public function __invoke(string $videoId): string + { + // Fetch the HTML content of the YouTube video page + $htmlResponse = $this->client->request('GET', 'https://youtube.com/watch?v='.$videoId); + $html = $htmlResponse->getContent(); + + // Use DomCrawler to parse the HTML + $crawler = new Crawler($html); + + // Extract the script containing the ytInitialPlayerResponse + $scriptContent = $crawler->filter('script')->reduce(function (Crawler $node) { + return str_contains($node->text(), 'var ytInitialPlayerResponse = {'); + })->text(); + + // Extract and parse the JSON data from the script + $start = strpos($scriptContent, 'var ytInitialPlayerResponse = ') + \strlen('var ytInitialPlayerResponse = '); + $dataString = substr($scriptContent, $start); + $dataString = substr($dataString, 0, strrpos($dataString, ';') ?: null); + $data = json_decode(trim($dataString), true); + + // Extract the URL for the captions + if (!isset($data['captions']['playerCaptionsTracklistRenderer']['captionTracks'][0]['baseUrl'])) { + throw new RuntimeException('Captions are not available for this video.'); + } + $captionsUrl = $data['captions']['playerCaptionsTracklistRenderer']['captionTracks'][0]['baseUrl']; + + // Fetch and parse the captions XML + $xmlResponse = $this->client->request('GET', $captionsUrl); + $xmlContent = $xmlResponse->getContent(); + $xmlCrawler = new Crawler($xmlContent); + + // Collect all text elements from the captions + $transcript = $xmlCrawler->filter('text')->each(function (Crawler $node) { + return $node->text().' '; + }); + + return implode(\PHP_EOL, $transcript); + } +} diff --git a/src/agent/src/Toolbox/ToolCallResult.php b/src/agent/src/Toolbox/ToolCallResult.php new file mode 100644 index 000000000..153631c29 --- /dev/null +++ b/src/agent/src/Toolbox/ToolCallResult.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @author Christopher Hertel + */ +final readonly class ToolCallResult +{ + public function __construct( + public ToolCall $toolCall, + public mixed $result, + ) { + } +} diff --git a/src/agent/src/Toolbox/ToolFactory/AbstractToolFactory.php b/src/agent/src/Toolbox/ToolFactory/AbstractToolFactory.php new file mode 100644 index 000000000..4a3a146bd --- /dev/null +++ b/src/agent/src/Toolbox/ToolFactory/AbstractToolFactory.php @@ -0,0 +1,44 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\ToolFactory; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Exception\ToolConfigurationException; +use Symfony\AI\Agent\Toolbox\ToolFactoryInterface; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @author Christopher Hertel + */ +abstract class AbstractToolFactory implements ToolFactoryInterface +{ + public function __construct( + private readonly Factory $factory = new Factory(), + ) { + } + + protected function convertAttribute(string $className, AsTool $attribute): Tool + { + try { + return new Tool( + new ExecutionReference($className, $attribute->method), + $attribute->name, + $attribute->description, + $this->factory->buildParameters($className, $attribute->method) + ); + } catch (\ReflectionException $e) { + throw ToolConfigurationException::invalidMethod($className, $attribute->method, $e); + } + } +} diff --git a/src/agent/src/Toolbox/ToolFactory/ChainFactory.php b/src/agent/src/Toolbox/ToolFactory/ChainFactory.php new file mode 100644 index 000000000..5fca2bc2a --- /dev/null +++ b/src/agent/src/Toolbox/ToolFactory/ChainFactory.php @@ -0,0 +1,54 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\ToolFactory; + +use Symfony\AI\Agent\Toolbox\Exception\ToolException; +use Symfony\AI\Agent\Toolbox\ToolFactoryInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ChainFactory implements ToolFactoryInterface +{ + /** + * @var list + */ + private array $factories; + + /** + * @param iterable $factories + */ + public function __construct(iterable $factories) + { + $this->factories = $factories instanceof \Traversable ? iterator_to_array($factories) : $factories; + } + + public function getTool(string $reference): iterable + { + $invalid = 0; + foreach ($this->factories as $factory) { + try { + yield from $factory->getTool($reference); + } catch (ToolException) { + ++$invalid; + continue; + } + + // If the factory does not throw an exception, we don't need to check the others + return; + } + + if ($invalid === \count($this->factories)) { + throw ToolException::invalidReference($reference); + } + } +} diff --git a/src/agent/src/Toolbox/ToolFactory/MemoryToolFactory.php b/src/agent/src/Toolbox/ToolFactory/MemoryToolFactory.php new file mode 100644 index 000000000..80846d96d --- /dev/null +++ b/src/agent/src/Toolbox/ToolFactory/MemoryToolFactory.php @@ -0,0 +1,48 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\ToolFactory; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Exception\ToolException; + +/** + * @author Christopher Hertel + */ +final class MemoryToolFactory extends AbstractToolFactory +{ + /** + * @var array + */ + private array $tools = []; + + public function addTool(string|object $class, string $name, string $description, string $method = '__invoke'): self + { + $className = \is_object($class) ? $class::class : $class; + $this->tools[$className][] = new AsTool($name, $description, $method); + + return $this; + } + + /** + * @param class-string $reference + */ + public function getTool(string $reference): iterable + { + if (!isset($this->tools[$reference])) { + throw ToolException::invalidReference($reference); + } + + foreach ($this->tools[$reference] as $tool) { + yield $this->convertAttribute($reference, $tool); + } + } +} diff --git a/src/agent/src/Toolbox/ToolFactory/ReflectionToolFactory.php b/src/agent/src/Toolbox/ToolFactory/ReflectionToolFactory.php new file mode 100644 index 000000000..8e76634e7 --- /dev/null +++ b/src/agent/src/Toolbox/ToolFactory/ReflectionToolFactory.php @@ -0,0 +1,44 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\ToolFactory; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Exception\ToolException; + +/** + * Metadata factory that uses reflection in combination with `#[AsTool]` attribute to extract metadata from tools. + * + * @author Christopher Hertel + */ +final class ReflectionToolFactory extends AbstractToolFactory +{ + /** + * @param class-string $reference + */ + public function getTool(string $reference): iterable + { + if (!class_exists($reference)) { + throw ToolException::invalidReference($reference); + } + + $reflectionClass = new \ReflectionClass($reference); + $attributes = $reflectionClass->getAttributes(AsTool::class); + + if (0 === \count($attributes)) { + throw ToolException::missingAttribute($reference); + } + + foreach ($attributes as $attribute) { + yield $this->convertAttribute($reference, $attribute->newInstance()); + } + } +} diff --git a/src/agent/src/Toolbox/ToolFactoryInterface.php b/src/agent/src/Toolbox/ToolFactoryInterface.php new file mode 100644 index 000000000..aafdf65d7 --- /dev/null +++ b/src/agent/src/Toolbox/ToolFactoryInterface.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Symfony\AI\Agent\Toolbox\Exception\ToolException; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @author Christopher Hertel + */ +interface ToolFactoryInterface +{ + /** + * @return iterable + * + * @throws ToolException if the metadata for the given reference is not found + */ + public function getTool(string $reference): iterable; +} diff --git a/src/agent/src/Toolbox/ToolResultConverter.php b/src/agent/src/Toolbox/ToolResultConverter.php new file mode 100644 index 000000000..f9e656c0e --- /dev/null +++ b/src/agent/src/Toolbox/ToolResultConverter.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +/** + * @author Christopher Hertel + */ +final readonly class ToolResultConverter +{ + /** + * @param \JsonSerializable|\Stringable|array|float|string|null $result + */ + public function convert(\JsonSerializable|\Stringable|array|float|string|\DateTimeInterface|null $result): ?string + { + if (null === $result) { + return null; + } + + if ($result instanceof \JsonSerializable || \is_array($result)) { + return json_encode($result, flags: \JSON_THROW_ON_ERROR); + } + + if (\is_float($result) || $result instanceof \Stringable) { + return (string) $result; + } + + if ($result instanceof \DateTimeInterface) { + return $result->format(\DATE_ATOM); + } + + return $result; + } +} diff --git a/src/agent/src/Toolbox/Toolbox.php b/src/agent/src/Toolbox/Toolbox.php new file mode 100644 index 000000000..bcab2c0e0 --- /dev/null +++ b/src/agent/src/Toolbox/Toolbox.php @@ -0,0 +1,110 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Psr\Log\LoggerInterface; +use Psr\Log\NullLogger; +use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; +use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @author Christopher Hertel + */ +final class Toolbox implements ToolboxInterface +{ + /** + * List of executable tools. + * + * @var list + */ + private readonly array $tools; + + /** + * List of tool metadata objects. + * + * @var Tool[] + */ + private array $map; + + /** + * @param iterable $tools + */ + public function __construct( + private readonly ToolFactoryInterface $toolFactory, + iterable $tools, + private readonly LoggerInterface $logger = new NullLogger(), + ) { + $this->tools = $tools instanceof \Traversable ? iterator_to_array($tools) : $tools; + } + + public static function create(object ...$tools): self + { + return new self(new ReflectionToolFactory(), $tools); + } + + public function getTools(): array + { + if (isset($this->map)) { + return $this->map; + } + + $map = []; + foreach ($this->tools as $tool) { + foreach ($this->toolFactory->getTool($tool::class) as $metadata) { + $map[] = $metadata; + } + } + + return $this->map = $map; + } + + public function execute(ToolCall $toolCall): mixed + { + $metadata = $this->getMetadata($toolCall); + $tool = $this->getExecutable($metadata); + + try { + $this->logger->debug(\sprintf('Executing tool "%s".', $toolCall->name), $toolCall->arguments); + $result = $tool->{$metadata->reference->method}(...$toolCall->arguments); + } catch (\Throwable $e) { + $this->logger->warning(\sprintf('Failed to execute tool "%s".', $toolCall->name), ['exception' => $e]); + throw ToolExecutionException::executionFailed($toolCall, $e); + } + + return $result; + } + + private function getMetadata(ToolCall $toolCall): Tool + { + foreach ($this->getTools() as $metadata) { + if ($metadata->name === $toolCall->name) { + return $metadata; + } + } + + throw ToolNotFoundException::notFoundForToolCall($toolCall); + } + + private function getExecutable(Tool $metadata): object + { + foreach ($this->tools as $tool) { + if ($tool instanceof $metadata->reference->class) { + return $tool; + } + } + + throw ToolNotFoundException::notFoundForReference($metadata->reference); + } +} diff --git a/src/agent/src/Toolbox/ToolboxInterface.php b/src/agent/src/Toolbox/ToolboxInterface.php new file mode 100644 index 000000000..15478644a --- /dev/null +++ b/src/agent/src/Toolbox/ToolboxInterface.php @@ -0,0 +1,34 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox; + +use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; +use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @author Christopher Hertel + */ +interface ToolboxInterface +{ + /** + * @return Tool[] + */ + public function getTools(): array; + + /** + * @throws ToolExecutionException if the tool execution fails + * @throws ToolNotFoundException if the tool is not found + */ + public function execute(ToolCall $toolCall): mixed; +} diff --git a/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php b/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php new file mode 100644 index 000000000..5acaddf90 --- /dev/null +++ b/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php @@ -0,0 +1,74 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\InputProcessor; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Exception\InvalidArgumentException; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\InputProcessor\ModelOverrideInputProcessor; +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Message\MessageBag; + +#[CoversClass(ModelOverrideInputProcessor::class)] +#[UsesClass(GPT::class)] +#[UsesClass(Claude::class)] +#[UsesClass(Input::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(Embeddings::class)] +#[Small] +final class ModelOverrideInputProcessorTest extends TestCase +{ + #[Test] + public function processInputWithValidModelOption(): void + { + $gpt = new GPT(); + $claude = new Claude(); + $input = new Input($gpt, new MessageBag(), ['model' => $claude]); + + $processor = new ModelOverrideInputProcessor(); + $processor->processInput($input); + + self::assertSame($claude, $input->model); + } + + #[Test] + public function processInputWithoutModelOption(): void + { + $gpt = new GPT(); + $input = new Input($gpt, new MessageBag(), []); + + $processor = new ModelOverrideInputProcessor(); + $processor->processInput($input); + + self::assertSame($gpt, $input->model); + } + + #[Test] + public function processInputWithInvalidModelOption(): void + { + self::expectException(InvalidArgumentException::class); + self::expectExceptionMessage('Option "model" must be an instance of Symfony\AI\Platform\Model.'); + + $gpt = new GPT(); + $model = new MessageBag(); + $input = new Input($gpt, new MessageBag(), ['model' => $model]); + + $processor = new ModelOverrideInputProcessor(); + $processor->processInput($input); + } +} diff --git a/src/agent/tests/InputProcessor/SystemPromptInputProcessorTest.php b/src/agent/tests/InputProcessor/SystemPromptInputProcessorTest.php new file mode 100644 index 000000000..77abe0c6c --- /dev/null +++ b/src/agent/tests/InputProcessor/SystemPromptInputProcessorTest.php @@ -0,0 +1,195 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\InputProcessor; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\InputProcessor\SystemPromptInputProcessor; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(SystemPromptInputProcessor::class)] +#[UsesClass(GPT::class)] +#[UsesClass(Message::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(Input::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(Text::class)] +#[UsesClass(Tool::class)] +#[UsesClass(ExecutionReference::class)] +#[Small] +final class SystemPromptInputProcessorTest extends TestCase +{ + #[Test] + public function processInputAddsSystemMessageWhenNoneExists(): void + { + $processor = new SystemPromptInputProcessor('This is a system prompt'); + + $input = new Input(new GPT(), new MessageBag(Message::ofUser('This is a user message')), []); + $processor->processInput($input); + + $messages = $input->messages->getMessages(); + self::assertCount(2, $messages); + self::assertInstanceOf(SystemMessage::class, $messages[0]); + self::assertInstanceOf(UserMessage::class, $messages[1]); + self::assertSame('This is a system prompt', $messages[0]->content); + } + + #[Test] + public function processInputDoesNotAddSystemMessageWhenOneExists(): void + { + $processor = new SystemPromptInputProcessor('This is a system prompt'); + + $messages = new MessageBag( + Message::forSystem('This is already a system prompt'), + Message::ofUser('This is a user message'), + ); + $input = new Input(new GPT(), $messages, []); + $processor->processInput($input); + + $messages = $input->messages->getMessages(); + self::assertCount(2, $messages); + self::assertInstanceOf(SystemMessage::class, $messages[0]); + self::assertInstanceOf(UserMessage::class, $messages[1]); + self::assertSame('This is already a system prompt', $messages[0]->content); + } + + #[Test] + public function doesNotIncludeToolsIfToolboxIsEmpty(): void + { + $processor = new SystemPromptInputProcessor( + 'This is a system prompt', + new class implements ToolboxInterface { + public function getTools(): array + { + return []; + } + + public function execute(ToolCall $toolCall): mixed + { + return null; + } + } + ); + + $input = new Input(new GPT(), new MessageBag(Message::ofUser('This is a user message')), []); + $processor->processInput($input); + + $messages = $input->messages->getMessages(); + self::assertCount(2, $messages); + self::assertInstanceOf(SystemMessage::class, $messages[0]); + self::assertInstanceOf(UserMessage::class, $messages[1]); + self::assertSame('This is a system prompt', $messages[0]->content); + } + + #[Test] + public function includeToolDefinitions(): void + { + $processor = new SystemPromptInputProcessor( + 'This is a system prompt', + new class implements ToolboxInterface { + public function getTools(): array + { + return [ + new Tool(new ExecutionReference(ToolNoParams::class), 'tool_no_params', 'A tool without parameters', null), + new Tool( + new ExecutionReference(ToolRequiredParams::class, 'bar'), + 'tool_required_params', + <<processInput($input); + + $messages = $input->messages->getMessages(); + self::assertCount(2, $messages); + self::assertInstanceOf(SystemMessage::class, $messages[0]); + self::assertInstanceOf(UserMessage::class, $messages[1]); + self::assertSame(<<content); + } + + #[Test] + public function withStringableSystemPrompt(): void + { + $processor = new SystemPromptInputProcessor( + new SystemPromptService(), + new class implements ToolboxInterface { + public function getTools(): array + { + return [ + new Tool(new ExecutionReference(ToolNoParams::class), 'tool_no_params', 'A tool without parameters', null), + ]; + } + + public function execute(ToolCall $toolCall): mixed + { + return null; + } + } + ); + + $input = new Input(new GPT(), new MessageBag(Message::ofUser('This is a user message')), []); + $processor->processInput($input); + + $messages = $input->messages->getMessages(); + self::assertCount(2, $messages); + self::assertInstanceOf(SystemMessage::class, $messages[0]); + self::assertInstanceOf(UserMessage::class, $messages[1]); + self::assertSame(<<content); + } +} diff --git a/src/agent/tests/InputProcessor/SystemPromptService.php b/src/agent/tests/InputProcessor/SystemPromptService.php new file mode 100644 index 000000000..1fc0e4945 --- /dev/null +++ b/src/agent/tests/InputProcessor/SystemPromptService.php @@ -0,0 +1,20 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\InputProcessor; + +final class SystemPromptService implements \Stringable +{ + public function __toString(): string + { + return 'My dynamic system prompt.'; + } +} diff --git a/src/agent/tests/StructuredOutput/ChainProcessorTest.php b/src/agent/tests/StructuredOutput/ChainProcessorTest.php new file mode 100644 index 000000000..b29d89992 --- /dev/null +++ b/src/agent/tests/StructuredOutput/ChainProcessorTest.php @@ -0,0 +1,173 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\StructuredOutput; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Exception\MissingModelSupportException; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\Output; +use Symfony\AI\Agent\StructuredOutput\AgentProcessor; +use Symfony\AI\Agent\Tests\Fixture\SomeStructure; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\MathReasoning; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\Step; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ObjectResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\Component\Serializer\SerializerInterface; + +#[CoversClass(AgentProcessor::class)] +#[UsesClass(Input::class)] +#[UsesClass(Output::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(Choice::class)] +#[UsesClass(MissingModelSupportException::class)] +#[UsesClass(TextResponse::class)] +#[UsesClass(ObjectResponse::class)] +#[UsesClass(Model::class)] +final class ChainProcessorTest extends TestCase +{ + #[Test] + public function processInputWithOutputStructure(): void + { + $agentProcessor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format'])); + + $model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]); + $input = new Input($model, new MessageBag(), ['output_structure' => 'SomeStructure']); + + $agentProcessor->processInput($input); + + self::assertSame(['response_format' => ['some' => 'format']], $input->getOptions()); + } + + #[Test] + public function processInputWithoutOutputStructure(): void + { + $agentProcessor = new AgentProcessor(new ConfigurableResponseFormatFactory()); + + $model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]); + $input = new Input($model, new MessageBag(), []); + + $agentProcessor->processInput($input); + + self::assertSame([], $input->getOptions()); + } + + #[Test] + public function processInputThrowsExceptionWhenLlmDoesNotSupportStructuredOutput(): void + { + self::expectException(MissingModelSupportException::class); + + $agentProcessor = new AgentProcessor(new ConfigurableResponseFormatFactory()); + + $model = new Model('gpt-3'); + $input = new Input($model, new MessageBag(), ['output_structure' => 'SomeStructure']); + + $agentProcessor->processInput($input); + } + + #[Test] + public function processOutputWithResponseFormat(): void + { + $agentProcessor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format'])); + + $model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]); + $options = ['output_structure' => SomeStructure::class]; + $input = new Input($model, new MessageBag(), $options); + $agentProcessor->processInput($input); + + $response = new TextResponse('{"some": "data"}'); + + $output = new Output($model, $response, new MessageBag(), $input->getOptions()); + + $agentProcessor->processOutput($output); + + self::assertInstanceOf(ObjectResponse::class, $output->response); + self::assertInstanceOf(SomeStructure::class, $output->response->getContent()); + self::assertSame('data', $output->response->getContent()->some); + } + + #[Test] + public function processOutputWithComplexResponseFormat(): void + { + $agentProcessor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format'])); + + $model = new Model('gpt-4', [Capability::OUTPUT_STRUCTURED]); + $options = ['output_structure' => MathReasoning::class]; + $input = new Input($model, new MessageBag(), $options); + $agentProcessor->processInput($input); + + $response = new TextResponse(<<getOptions()); + + $agentProcessor->processOutput($output); + + self::assertInstanceOf(ObjectResponse::class, $output->response); + self::assertInstanceOf(MathReasoning::class, $structure = $output->response->getContent()); + self::assertCount(5, $structure->steps); + self::assertInstanceOf(Step::class, $structure->steps[0]); + self::assertInstanceOf(Step::class, $structure->steps[1]); + self::assertInstanceOf(Step::class, $structure->steps[2]); + self::assertInstanceOf(Step::class, $structure->steps[3]); + self::assertInstanceOf(Step::class, $structure->steps[4]); + self::assertSame('x = -3.75', $structure->finalAnswer); + } + + #[Test] + public function processOutputWithoutResponseFormat(): void + { + $responseFormatFactory = new ConfigurableResponseFormatFactory(); + $serializer = self::createMock(SerializerInterface::class); + $agentProcessor = new AgentProcessor($responseFormatFactory, $serializer); + + $model = self::createMock(Model::class); + $response = new TextResponse(''); + + $output = new Output($model, $response, new MessageBag(), []); + + $agentProcessor->processOutput($output); + + self::assertSame($response, $output->response); + } +} diff --git a/src/agent/tests/StructuredOutput/ConfigurableResponseFormatFactory.php b/src/agent/tests/StructuredOutput/ConfigurableResponseFormatFactory.php new file mode 100644 index 000000000..74a0bdbf8 --- /dev/null +++ b/src/agent/tests/StructuredOutput/ConfigurableResponseFormatFactory.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\StructuredOutput; + +use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface; + +final readonly class ConfigurableResponseFormatFactory implements ResponseFormatFactoryInterface +{ + /** + * @param array $responseFormat + */ + public function __construct( + private array $responseFormat = [], + ) { + } + + public function create(string $responseClass): array + { + return $this->responseFormat; + } +} diff --git a/src/agent/tests/StructuredOutput/ResponseFormatFactoryTest.php b/src/agent/tests/StructuredOutput/ResponseFormatFactoryTest.php new file mode 100644 index 000000000..90af8a84f --- /dev/null +++ b/src/agent/tests/StructuredOutput/ResponseFormatFactoryTest.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\StructuredOutput; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactory; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\User; +use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; + +#[CoversClass(ResponseFormatFactory::class)] +#[UsesClass(DescriptionParser::class)] +#[UsesClass(Factory::class)] +final class ResponseFormatFactoryTest extends TestCase +{ + #[Test] + public function create(): void + { + self::assertSame([ + 'type' => 'json_schema', + 'json_schema' => [ + 'name' => 'User', + 'schema' => [ + 'type' => 'object', + 'properties' => [ + 'id' => ['type' => 'integer'], + 'name' => [ + 'type' => 'string', + 'description' => 'The name of the user in lowercase', + ], + 'createdAt' => [ + 'type' => 'string', + 'format' => 'date-time', + ], + 'isActive' => ['type' => 'boolean'], + 'age' => ['type' => ['integer', 'null']], + ], + 'required' => ['id', 'name', 'createdAt', 'isActive'], + 'additionalProperties' => false, + ], + 'strict' => true, + ], + ], (new ResponseFormatFactory())->create(User::class)); + } +} diff --git a/src/agent/tests/Toolbox/Attribute/AsToolTest.php b/src/agent/tests/Toolbox/Attribute/AsToolTest.php new file mode 100644 index 000000000..8d54a18ff --- /dev/null +++ b/src/agent/tests/Toolbox/Attribute/AsToolTest.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\Attribute; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[CoversClass(AsTool::class)] +final class AsToolTest extends TestCase +{ + #[Test] + public function canBeConstructed(): void + { + $attribute = new AsTool( + name: 'name', + description: 'description', + ); + + self::assertSame('name', $attribute->name); + self::assertSame('description', $attribute->description); + } +} diff --git a/src/agent/tests/Toolbox/ChainProcessorTest.php b/src/agent/tests/Toolbox/ChainProcessorTest.php new file mode 100644 index 000000000..f3224d06a --- /dev/null +++ b/src/agent/tests/Toolbox/ChainProcessorTest.php @@ -0,0 +1,97 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Exception\MissingModelSupportException; +use Symfony\AI\Agent\Input; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(AgentProcessor::class)] +#[UsesClass(Input::class)] +#[UsesClass(Tool::class)] +#[UsesClass(ExecutionReference::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(MissingModelSupportException::class)] +#[UsesClass(Model::class)] +class ChainProcessorTest extends TestCase +{ + #[Test] + public function processInputWithoutRegisteredToolsWillResultInNoOptionChange(): void + { + $toolbox = $this->createStub(ToolboxInterface::class); + $toolbox->method('getTools')->willReturn([]); + + $model = new Model('gpt-4', [Capability::TOOL_CALLING]); + $agentProcessor = new AgentProcessor($toolbox); + $input = new Input($model, new MessageBag(), []); + + $agentProcessor->processInput($input); + + self::assertSame([], $input->getOptions()); + } + + #[Test] + public function processInputWithRegisteredToolsWillResultInOptionChange(): void + { + $toolbox = $this->createStub(ToolboxInterface::class); + $tool1 = new Tool(new ExecutionReference('ClassTool1', 'method1'), 'tool1', 'description1', null); + $tool2 = new Tool(new ExecutionReference('ClassTool2', 'method1'), 'tool2', 'description2', null); + $toolbox->method('getTools')->willReturn([$tool1, $tool2]); + + $model = new Model('gpt-4', [Capability::TOOL_CALLING]); + $agentProcessor = new AgentProcessor($toolbox); + $input = new Input($model, new MessageBag(), []); + + $agentProcessor->processInput($input); + + self::assertSame(['tools' => [$tool1, $tool2]], $input->getOptions()); + } + + #[Test] + public function processInputWithRegisteredToolsButToolOverride(): void + { + $toolbox = $this->createStub(ToolboxInterface::class); + $tool1 = new Tool(new ExecutionReference('ClassTool1', 'method1'), 'tool1', 'description1', null); + $tool2 = new Tool(new ExecutionReference('ClassTool2', 'method1'), 'tool2', 'description2', null); + $toolbox->method('getTools')->willReturn([$tool1, $tool2]); + + $model = new Model('gpt-4', [Capability::TOOL_CALLING]); + $agentProcessor = new AgentProcessor($toolbox); + $input = new Input($model, new MessageBag(), ['tools' => ['tool2']]); + + $agentProcessor->processInput($input); + + self::assertSame(['tools' => [$tool2]], $input->getOptions()); + } + + #[Test] + public function processInputWithUnsupportedToolCallingWillThrowException(): void + { + self::expectException(MissingModelSupportException::class); + + $model = new Model('gpt-3'); + $agentProcessor = new AgentProcessor($this->createStub(ToolboxInterface::class)); + $input = new Input($model, new MessageBag(), []); + + $agentProcessor->processInput($input); + } +} diff --git a/src/agent/tests/Toolbox/FaultTolerantToolboxTest.php b/src/agent/tests/Toolbox/FaultTolerantToolboxTest.php new file mode 100644 index 000000000..53bb47754 --- /dev/null +++ b/src/agent/tests/Toolbox/FaultTolerantToolboxTest.php @@ -0,0 +1,92 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; +use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Agent\Toolbox\FaultTolerantToolbox; +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(FaultTolerantToolbox::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(Tool::class)] +#[UsesClass(ExecutionReference::class)] +#[UsesClass(ToolNotFoundException::class)] +#[UsesClass(ToolExecutionException::class)] +final class FaultTolerantToolboxTest extends TestCase +{ + #[Test] + public function faultyToolExecution(): void + { + $faultyToolbox = $this->createFaultyToolbox( + fn (ToolCall $toolCall) => ToolExecutionException::executionFailed($toolCall, new \Exception('error')) + ); + + $faultTolerantToolbox = new FaultTolerantToolbox($faultyToolbox); + $expected = 'An error occurred while executing tool "tool_foo".'; + + $toolCall = new ToolCall('987654321', 'tool_foo'); + $actual = $faultTolerantToolbox->execute($toolCall); + + self::assertSame($expected, $actual); + } + + #[Test] + public function faultyToolCall(): void + { + $faultyToolbox = $this->createFaultyToolbox( + fn (ToolCall $toolCall) => ToolNotFoundException::notFoundForToolCall($toolCall) + ); + + $faultTolerantToolbox = new FaultTolerantToolbox($faultyToolbox); + $expected = 'Tool "tool_xyz" was not found, please use one of these: tool_no_params, tool_required_params'; + + $toolCall = new ToolCall('123456789', 'tool_xyz'); + $actual = $faultTolerantToolbox->execute($toolCall); + + self::assertSame($expected, $actual); + } + + private function createFaultyToolbox(\Closure $exceptionFactory): ToolboxInterface + { + return new class($exceptionFactory) implements ToolboxInterface { + public function __construct(private readonly \Closure $exceptionFactory) + { + } + + /** + * @return Tool[] + */ + public function getTools(): array + { + return [ + new Tool(new ExecutionReference(ToolNoParams::class), 'tool_no_params', 'A tool without parameters', null), + new Tool(new ExecutionReference(ToolRequiredParams::class, 'bar'), 'tool_required_params', 'A tool with required parameters', null), + ]; + } + + public function execute(ToolCall $toolCall): mixed + { + throw ($this->exceptionFactory)($toolCall); + } + }; + } +} diff --git a/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php b/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php new file mode 100644 index 000000000..41044dab8 --- /dev/null +++ b/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php @@ -0,0 +1,101 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\MetadataFactory; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Medium; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolMisconfigured; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolMultiple; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoAttribute1; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolOptionalParam; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Toolbox\Exception\ToolConfigurationException; +use Symfony\AI\Agent\Toolbox\Exception\ToolException; +use Symfony\AI\Agent\Toolbox\ToolFactory\ChainFactory; +use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory; +use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; + +#[CoversClass(ChainFactory::class)] +#[Medium] +#[UsesClass(MemoryToolFactory::class)] +#[UsesClass(ReflectionToolFactory::class)] +#[UsesClass(ToolException::class)] +final class ChainFactoryTest extends TestCase +{ + private ChainFactory $factory; + + protected function setUp(): void + { + $factory1 = (new MemoryToolFactory()) + ->addTool(ToolNoAttribute1::class, 'reference', 'A reference tool') + ->addTool(ToolOptionalParam::class, 'optional_param', 'Tool with optional param', 'bar'); + $factory2 = new ReflectionToolFactory(); + + $this->factory = new ChainFactory([$factory1, $factory2]); + } + + #[Test] + public function testGetMetadataNotExistingClass(): void + { + self::expectException(ToolException::class); + self::expectExceptionMessage('The reference "NoClass" is not a valid tool.'); + + iterator_to_array($this->factory->getTool('NoClass')); + } + + #[Test] + public function testGetMetadataNotConfiguredClass(): void + { + self::expectException(ToolConfigurationException::class); + self::expectExceptionMessage(\sprintf('Method "foo" not found in tool "%s".', ToolMisconfigured::class)); + + iterator_to_array($this->factory->getTool(ToolMisconfigured::class)); + } + + #[Test] + public function testGetMetadataWithAttributeSingleHit(): void + { + $metadata = iterator_to_array($this->factory->getTool(ToolRequiredParams::class)); + + self::assertCount(1, $metadata); + } + + #[Test] + public function testGetMetadataOverwrite(): void + { + $metadata = iterator_to_array($this->factory->getTool(ToolOptionalParam::class)); + + self::assertCount(1, $metadata); + self::assertSame('optional_param', $metadata[0]->name); + self::assertSame('Tool with optional param', $metadata[0]->description); + self::assertSame('bar', $metadata[0]->reference->method); + } + + #[Test] + public function testGetMetadataWithAttributeDoubleHit(): void + { + $metadata = iterator_to_array($this->factory->getTool(ToolMultiple::class)); + + self::assertCount(2, $metadata); + } + + #[Test] + public function testGetMetadataWithMemorySingleHit(): void + { + $metadata = iterator_to_array($this->factory->getTool(ToolNoAttribute1::class)); + + self::assertCount(1, $metadata); + } +} diff --git a/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php b/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php new file mode 100644 index 000000000..d59447427 --- /dev/null +++ b/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php @@ -0,0 +1,116 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\MetadataFactory; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoAttribute1; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoAttribute2; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Exception\ToolException; +use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory; +use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(MemoryToolFactory::class)] +#[UsesClass(AsTool::class)] +#[UsesClass(Tool::class)] +#[UsesClass(ExecutionReference::class)] +#[UsesClass(ToolException::class)] +#[UsesClass(Factory::class)] +#[UsesClass(DescriptionParser::class)] +final class MemoryFactoryTest extends TestCase +{ + #[Test] + public function getMetadataWithoutTools(): void + { + self::expectException(ToolException::class); + self::expectExceptionMessage('The reference "SomeClass" is not a valid tool.'); + + $factory = new MemoryToolFactory(); + iterator_to_array($factory->getTool('SomeClass')); // @phpstan-ignore-line Yes, this class does not exist + } + + #[Test] + public function getMetadataWithDistinctToolPerClass(): void + { + $factory = (new MemoryToolFactory()) + ->addTool(ToolNoAttribute1::class, 'happy_birthday', 'Generates birthday message') + ->addTool(new ToolNoAttribute2(), 'checkout', 'Buys a number of items per product', 'buy'); + + $metadata = iterator_to_array($factory->getTool(ToolNoAttribute1::class)); + + self::assertCount(1, $metadata); + self::assertInstanceOf(Tool::class, $metadata[0]); + self::assertSame('happy_birthday', $metadata[0]->name); + self::assertSame('Generates birthday message', $metadata[0]->description); + self::assertSame('__invoke', $metadata[0]->reference->method); + + $expectedParams = [ + 'type' => 'object', + 'properties' => [ + 'name' => ['type' => 'string', 'description' => 'the name of the person'], + 'years' => ['type' => 'integer', 'description' => 'the age of the person'], + ], + 'required' => ['name', 'years'], + 'additionalProperties' => false, + ]; + + self::assertSame($expectedParams, $metadata[0]->parameters); + } + + #[Test] + public function getMetadataWithMultipleToolsInClass(): void + { + $factory = (new MemoryToolFactory()) + ->addTool(ToolNoAttribute2::class, 'checkout', 'Buys a number of items per product', 'buy') + ->addTool(ToolNoAttribute2::class, 'cancel', 'Cancels an order', 'cancel'); + + $metadata = iterator_to_array($factory->getTool(ToolNoAttribute2::class)); + + self::assertCount(2, $metadata); + self::assertInstanceOf(Tool::class, $metadata[0]); + self::assertSame('checkout', $metadata[0]->name); + self::assertSame('Buys a number of items per product', $metadata[0]->description); + self::assertSame('buy', $metadata[0]->reference->method); + + $expectedParams = [ + 'type' => 'object', + 'properties' => [ + 'id' => ['type' => 'integer', 'description' => 'the ID of the product'], + 'amount' => ['type' => 'integer', 'description' => 'the number of products'], + ], + 'required' => ['id', 'amount'], + 'additionalProperties' => false, + ]; + self::assertSame($expectedParams, $metadata[0]->parameters); + + self::assertInstanceOf(Tool::class, $metadata[1]); + self::assertSame('cancel', $metadata[1]->name); + self::assertSame('Cancels an order', $metadata[1]->description); + self::assertSame('cancel', $metadata[1]->reference->method); + + $expectedParams = [ + 'type' => 'object', + 'properties' => [ + 'orderId' => ['type' => 'string', 'description' => 'the ID of the order'], + ], + 'required' => ['orderId'], + 'additionalProperties' => false, + ]; + self::assertSame($expectedParams, $metadata[1]->parameters); + } +} diff --git a/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php b/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php new file mode 100644 index 000000000..e6133cbec --- /dev/null +++ b/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php @@ -0,0 +1,155 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\MetadataFactory; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolMultiple; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolWrong; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Exception\ToolConfigurationException; +use Symfony\AI\Agent\Toolbox\Exception\ToolException; +use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; +use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(ReflectionToolFactory::class)] +#[UsesClass(AsTool::class)] +#[UsesClass(Tool::class)] +#[UsesClass(ExecutionReference::class)] +#[UsesClass(Factory::class)] +#[UsesClass(DescriptionParser::class)] +#[UsesClass(ToolConfigurationException::class)] +#[UsesClass(ToolException::class)] +final class ReflectionFactoryTest extends TestCase +{ + private ReflectionToolFactory $factory; + + protected function setUp(): void + { + $this->factory = new ReflectionToolFactory(); + } + + #[Test] + public function invalidReferenceNonExistingClass(): void + { + self::expectException(ToolException::class); + self::expectExceptionMessage('The reference "invalid" is not a valid tool.'); + + iterator_to_array($this->factory->getTool('invalid')); // @phpstan-ignore-line Yes, this class does not exist + } + + #[Test] + public function withoutAttribute(): void + { + self::expectException(ToolException::class); + self::expectExceptionMessage(\sprintf('The class "%s" is not a tool, please add %s attribute.', ToolWrong::class, AsTool::class)); + + iterator_to_array($this->factory->getTool(ToolWrong::class)); + } + + #[Test] + public function getDefinition(): void + { + /** @var Tool[] $metadatas */ + $metadatas = iterator_to_array($this->factory->getTool(ToolRequiredParams::class)); + + self::assertToolConfiguration( + metadata: $metadatas[0], + className: ToolRequiredParams::class, + name: 'tool_required_params', + description: 'A tool with required parameters', + method: 'bar', + parameters: [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ); + } + + #[Test] + public function getDefinitionWithMultiple(): void + { + $metadatas = iterator_to_array($this->factory->getTool(ToolMultiple::class)); + + self::assertCount(2, $metadatas); + + [$first, $second] = $metadatas; + + self::assertToolConfiguration( + metadata: $first, + className: ToolMultiple::class, + name: 'tool_hello_world', + description: 'Function to say hello', + method: 'hello', + parameters: [ + 'type' => 'object', + 'properties' => [ + 'world' => [ + 'type' => 'string', + 'description' => 'The world to say hello to', + ], + ], + 'required' => ['world'], + 'additionalProperties' => false, + ], + ); + + self::assertToolConfiguration( + metadata: $second, + className: ToolMultiple::class, + name: 'tool_required_params', + description: 'Function to say a number', + method: 'bar', + parameters: [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ); + } + + private function assertToolConfiguration(Tool $metadata, string $className, string $name, string $description, string $method, array $parameters): void + { + self::assertSame($className, $metadata->reference->class); + self::assertSame($method, $metadata->reference->method); + self::assertSame($name, $metadata->name); + self::assertSame($description, $metadata->description); + self::assertSame($parameters, $metadata->parameters); + } +} diff --git a/src/agent/tests/Toolbox/Tool/BraveTest.php b/src/agent/tests/Toolbox/Tool/BraveTest.php new file mode 100644 index 000000000..730fbe2e9 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/BraveTest.php @@ -0,0 +1,82 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\Tool; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\Tool\Brave; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; +use Symfony\Component\HttpClient\Response\MockResponse; + +#[CoversClass(Brave::class)] +final class BraveTest extends TestCase +{ + #[Test] + public function returnsSearchResults(): void + { + $response = $this->jsonMockResponseFromFile(__DIR__.'/fixtures/brave.json'); + $httpClient = new MockHttpClient($response); + $brave = new Brave($httpClient, 'test-api-key'); + + $results = $brave('latest Dallas Cowboys game result'); + + self::assertCount(5, $results); + self::assertArrayHasKey('title', $results[0]); + self::assertSame('Dallas Cowboys Scores, Stats and Highlights - ESPN', $results[0]['title']); + self::assertArrayHasKey('description', $results[0]); + self::assertSame('Visit ESPN for Dallas Cowboys live scores, video highlights, and latest news. Find standings and the full 2024 season schedule.', $results[0]['description']); + self::assertArrayHasKey('url', $results[0]); + self::assertSame('https://www.espn.com/nfl/team/_/name/dal/dallas-cowboys', $results[0]['url']); + } + + #[Test] + public function passesCorrectParametersToApi(): void + { + $response = $this->jsonMockResponseFromFile(__DIR__.'/fixtures/brave.json'); + $httpClient = new MockHttpClient($response); + $brave = new Brave($httpClient, 'test-api-key', ['extra' => 'option']); + + $brave('test query', 10, 5); + + $request = $response->getRequestUrl(); + self::assertStringContainsString('q=test%20query', $request); + self::assertStringContainsString('count=10', $request); + self::assertStringContainsString('offset=5', $request); + self::assertStringContainsString('extra=option', $request); + + $requestOptions = $response->getRequestOptions(); + self::assertArrayHasKey('headers', $requestOptions); + self::assertContains('X-Subscription-Token: test-api-key', $requestOptions['headers']); + } + + #[Test] + public function handlesEmptyResults(): void + { + $response = new MockResponse(json_encode(['web' => ['results' => []]])); + $httpClient = new MockHttpClient($response); + $brave = new Brave($httpClient, 'test-api-key'); + + $results = $brave('this should return nothing'); + + self::assertEmpty($results); + } + + /** + * This can be replaced by `JsonMockResponse::fromFile` when dropping Symfony 6.4. + */ + private function jsonMockResponseFromFile(string $file): JsonMockResponse + { + return new JsonMockResponse(json_decode(file_get_contents($file), true)); + } +} diff --git a/src/agent/tests/Toolbox/Tool/OpenMeteoTest.php b/src/agent/tests/Toolbox/Tool/OpenMeteoTest.php new file mode 100644 index 000000000..6a9050b18 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/OpenMeteoTest.php @@ -0,0 +1,83 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\Tool; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\Tool\OpenMeteo; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; + +#[CoversClass(OpenMeteo::class)] +final class OpenMeteoTest extends TestCase +{ + #[Test] + public function current(): void + { + $response = $this->jsonMockResponseFromFile(__DIR__.'/fixtures/openmeteo-current.json'); + $httpClient = new MockHttpClient($response); + + $openMeteo = new OpenMeteo($httpClient); + + $actual = $openMeteo->current(52.52, 13.42); + $expected = [ + 'weather' => 'Overcast', + 'time' => '2024-12-21T01:15', + 'temperature' => '2.6°C', + 'wind_speed' => '10.7km/h', + ]; + + static::assertSame($expected, $actual); + } + + #[Test] + public function forecast(): void + { + $response = $this->jsonMockResponseFromFile(__DIR__.'/fixtures/openmeteo-forecast.json'); + $httpClient = new MockHttpClient($response); + + $openMeteo = new OpenMeteo($httpClient); + + $actual = $openMeteo->forecast(52.52, 13.42, 3); + $expected = [ + [ + 'weather' => 'Light Rain', + 'time' => '2024-12-21', + 'temperature_min' => '2°C', + 'temperature_max' => '6°C', + ], + [ + 'weather' => 'Light Showers', + 'time' => '2024-12-22', + 'temperature_min' => '1.3°C', + 'temperature_max' => '6.4°C', + ], + [ + 'weather' => 'Light Snow Showers', + 'time' => '2024-12-23', + 'temperature_min' => '1.5°C', + 'temperature_max' => '4.1°C', + ], + ]; + + static::assertSame($expected, $actual); + } + + /** + * This can be replaced by `JsonMockResponse::fromFile` when dropping Symfony 6.4. + */ + private function jsonMockResponseFromFile(string $file): JsonMockResponse + { + return new JsonMockResponse(json_decode(file_get_contents($file), true)); + } +} diff --git a/src/agent/tests/Toolbox/Tool/WikipediaTest.php b/src/agent/tests/Toolbox/Tool/WikipediaTest.php new file mode 100644 index 000000000..1547eb93c --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/WikipediaTest.php @@ -0,0 +1,123 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox\Tool; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\Tool\Wikipedia; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; + +#[CoversClass(Wikipedia::class)] +final class WikipediaTest extends TestCase +{ + #[Test] + public function searchWithResults(): void + { + $response = $this->jsonMockResponseFromFile(__DIR__.'/fixtures/wikipedia-search-result.json'); + $httpClient = new MockHttpClient($response); + + $wikipedia = new Wikipedia($httpClient); + + $actual = $wikipedia->search('current secretary of the united nations'); + $expected = <<jsonMockResponseFromFile(__DIR__.'/fixtures/wikipedia-search-empty.json'); + $httpClient = new MockHttpClient($response); + + $wikipedia = new Wikipedia($httpClient); + + $actual = $wikipedia->search('weird questions without results'); + $expected = 'No articles were found on Wikipedia.'; + + static::assertSame($expected, $actual); + } + + #[Test] + public function articleWithResult(): void + { + $response = $this->jsonMockResponseFromFile(__DIR__.'/fixtures/wikipedia-article.json'); + $httpClient = new MockHttpClient($response); + + $wikipedia = new Wikipedia($httpClient); + + $actual = $wikipedia->article('Secretary-General of the United Nations'); + $expected = <<jsonMockResponseFromFile(__DIR__.'/fixtures/wikipedia-article-redirect.json'); + $httpClient = new MockHttpClient($response); + + $wikipedia = new Wikipedia($httpClient); + + $actual = $wikipedia->article('United Nations secretary-general'); + $expected = <<jsonMockResponseFromFile(__DIR__.'/fixtures/wikipedia-article-missing.json'); + $httpClient = new MockHttpClient($response); + + $wikipedia = new Wikipedia($httpClient); + + $actual = $wikipedia->article('Blah blah blah'); + $expected = 'No article with title "Blah blah blah" was found on Wikipedia.'; + + static::assertSame($expected, $actual); + } + + /** + * This can be replaced by `JsonMockResponse::fromFile` when dropping Symfony 6.4. + */ + private function jsonMockResponseFromFile(string $file): JsonMockResponse + { + return new JsonMockResponse(json_decode(file_get_contents($file), true)); + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/brave.json b/src/agent/tests/Toolbox/Tool/fixtures/brave.json new file mode 100644 index 000000000..d8793382b --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/brave.json @@ -0,0 +1,276 @@ +{ + "query": { + "original": "latest Dallas Cowboys game result", + "show_strict_warning": false, + "is_navigational": false, + "is_news_breaking": false, + "spellcheck_off": true, + "country": "us", + "bad_results": false, + "should_fallback": false, + "postal_code": "", + "city": "", + "header_country": "", + "more_results_available": true, + "state": "" + }, + "mixed": { + "type": "mixed", + "main": [ + { + "type": "web", + "index": 0, + "all": false + }, + { + "type": "web", + "index": 1, + "all": false + }, + { + "type": "web", + "index": 2, + "all": false + }, + { + "type": "web", + "index": 3, + "all": false + }, + { + "type": "web", + "index": 4, + "all": false + }, + { + "type": "web", + "index": 5, + "all": false + }, + { + "type": "web", + "index": 6, + "all": false + }, + { + "type": "web", + "index": 7, + "all": false + }, + { + "type": "web", + "index": 8, + "all": false + }, + { + "type": "web", + "index": 9, + "all": false + }, + { + "type": "web", + "index": 10, + "all": false + }, + { + "type": "web", + "index": 11, + "all": false + }, + { + "type": "web", + "index": 12, + "all": false + }, + { + "type": "web", + "index": 13, + "all": false + }, + { + "type": "web", + "index": 14, + "all": false + }, + { + "type": "web", + "index": 15, + "all": false + }, + { + "type": "web", + "index": 16, + "all": false + }, + { + "type": "web", + "index": 17, + "all": false + }, + { + "type": "web", + "index": 18, + "all": false + } + ], + "top": [], + "side": [] + }, + "type": "search", + "web": { + "type": "search", + "results": [ + { + "title": "Dallas Cowboys Scores, Stats and Highlights - ESPN", + "url": "https://www.espn.com/nfl/team/_/name/dal/dallas-cowboys", + "is_source_local": false, + "is_source_both": false, + "description": "Visit ESPN for Dallas Cowboys live scores, video highlights, and latest news. Find standings and the full 2024 season schedule.", + "profile": { + "name": "ESPN", + "url": "https://www.espn.com/nfl/team/_/name/dal/dallas-cowboys", + "long_name": "Entertainment and Sports Programming Network", + "img": "https://imgs.search.brave.com/Kz1hWnjcBXLBXExGU0hCyCn2-pB94hTqPkqNv2qL9Ds/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvM2MzMjQzYzM4/MGZiMjZlNDJlY2Iy/ZjM1N2RjZjMxYzhk/YWNiNmVlMGViMDRl/ZGVhYzJjMzI4OTkz/NDY0MGI4MS93d3cu/ZXNwbi5jb20v" + }, + "language": "en", + "family_friendly": true, + "type": "search_result", + "subtype": "generic", + "is_live": false, + "meta_url": { + "scheme": "https", + "netloc": "espn.com", + "hostname": "www.espn.com", + "favicon": "https://imgs.search.brave.com/Kz1hWnjcBXLBXExGU0hCyCn2-pB94hTqPkqNv2qL9Ds/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvM2MzMjQzYzM4/MGZiMjZlNDJlY2Iy/ZjM1N2RjZjMxYzhk/YWNiNmVlMGViMDRl/ZGVhYzJjMzI4OTkz/NDY0MGI4MS93d3cu/ZXNwbi5jb20v", + "path": "› nfl › team › _ › name › dal › dallas-cowboys" + }, + "thumbnail": { + "src": "https://imgs.search.brave.com/yoylqmAf8Idap3k8AvVgIN8VAnBC3qYLTIPUhr-dngk/rs:fit:200:200:1:0/g:ce/aHR0cHM6Ly9hLmVz/cG5jZG4uY29tL2Nv/bWJpbmVyL2k_aW1n/PS9pL3RlYW1sb2dv/cy9uZmwvNTAwL2Rh/bC5wbmc", + "original": "https://a.espncdn.com/combiner/i?img=/i/teamlogos/nfl/500/dal.png", + "logo": true + } + }, + { + "title": "Dallas Cowboys | Official Site of the Dallas Cowboys", + "url": "https://www.dallascowboys.com/", + "is_source_local": false, + "is_source_both": false, + "description": "As Mickey Spagnola writes in his Friday column, the Cowboys have to be ready for anything in this upcoming draft as the best-laid plans can often go awry. In this week's Mick Shots, Mickey Spagnola looks at how depth options could play a part on who to pick at No. 12. Plus, a closer look at ...", + "profile": { + "name": "Dallascowboys", + "url": "https://www.dallascowboys.com/", + "long_name": "dallascowboys.com", + "img": "https://imgs.search.brave.com/jlBcXKzEJ8ZVEM7kjduly5zdZdDd3ZKJa3KtSH06rxk/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvODk0YjBmYmE0/N2E2ZmM4NjgxYzI2/ZmZmMWMxODE3YTMz/MmM5YmQ4MDBkZmM3/NjFiOWNlYzczMGUz/OTg3NWRhZi93d3cu/ZGFsbGFzY293Ym95/cy5jb20v" + }, + "language": "en", + "family_friendly": true, + "type": "search_result", + "subtype": "generic", + "is_live": false, + "meta_url": { + "scheme": "https", + "netloc": "dallascowboys.com", + "hostname": "www.dallascowboys.com", + "favicon": "https://imgs.search.brave.com/jlBcXKzEJ8ZVEM7kjduly5zdZdDd3ZKJa3KtSH06rxk/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvODk0YjBmYmE0/N2E2ZmM4NjgxYzI2/ZmZmMWMxODE3YTMz/MmM5YmQ4MDBkZmM3/NjFiOWNlYzczMGUz/OTg3NWRhZi93d3cu/ZGFsbGFzY293Ym95/cy5jb20v", + "path": "" + }, + "thumbnail": { + "src": "https://imgs.search.brave.com/IfNsWGGP4OLU1pxkWNTYIeZtTxjykyoRTWTaYkKDPU0/rs:fit:200:200:1:0/g:ce/aHR0cHM6Ly93d3cu/ZGFsbGFzY293Ym95/cy5jb20v", + "original": "https://www.dallascowboys.com/", + "logo": false + } + }, + { + "title": "Dallas Cowboys News, Scores, Status, Schedule - NFL - CBSSports.com", + "url": "https://www.cbssports.com/nfl/teams/DAL/dallas-cowboys/", + "is_source_local": false, + "is_source_both": false, + "description": "Get the latest news and information for the Dallas Cowboys. 2024 season schedule, scores, stats, and highlights. Find out the latest on your favorite NFL teams on CBSSports.com.", + "profile": { + "name": "Cbssports", + "url": "https://www.cbssports.com/nfl/teams/DAL/dallas-cowboys/", + "long_name": "cbssports.com", + "img": "https://imgs.search.brave.com/G8DEk0_A87RxEMyNA8Uhu5GaN1usv62iX_74SwwTHSk/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvM2FlZjEzYmM3/NzkwMzQ5ZWYwMWQ3/YjJiZGM5MGMxMWFl/ZDBlNmQxMTk2N2Fm/MjljMzU2OGIzMTUz/M2Q4ZjcxNS93d3cu/Y2Jzc3BvcnRzLmNv/bS8" + }, + "language": "en", + "family_friendly": true, + "type": "search_result", + "subtype": "generic", + "is_live": false, + "meta_url": { + "scheme": "https", + "netloc": "cbssports.com", + "hostname": "www.cbssports.com", + "favicon": "https://imgs.search.brave.com/G8DEk0_A87RxEMyNA8Uhu5GaN1usv62iX_74SwwTHSk/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvM2FlZjEzYmM3/NzkwMzQ5ZWYwMWQ3/YjJiZGM5MGMxMWFl/ZDBlNmQxMTk2N2Fm/MjljMzU2OGIzMTUz/M2Q4ZjcxNS93d3cu/Y2Jzc3BvcnRzLmNv/bS8", + "path": "› nfl › teams › DAL › dallas-cowboys" + }, + "thumbnail": { + "src": "https://imgs.search.brave.com/9YB61Wfb-DCqHUR_XH_Tq7Iwy9KW9qCjCaO0bKpQ_bU/rs:fit:200:200:1:0/g:ce/aHR0cHM6Ly9zcG9y/dHNmbHkuY2JzaXN0/YXRpYy5jb20vZmx5/LTA5NDQvYnVuZGxl/cy9zcG9ydHNtZWRp/YWNzcy9pbWFnZXMv/ZmFudGFzeS9kZWZh/dWx0LWFydGljbGUt/aW1hZ2UtbGFyZ2Uu/cG5n", + "original": "https://sportsfly.cbsistatic.com/fly-0944/bundles/sportsmediacss/images/fantasy/default-article-image-large.png", + "logo": false + } + }, + { + "title": "Dallas Cowboys News, Scores, Stats, Schedule | NFL.com", + "url": "https://www.nfl.com/teams/dallas-cowboys/", + "is_source_local": false, + "is_source_both": false, + "description": "Dallas Cowboys · 3rd NFC East · 7 - 10 - 0 Buy Gear · Official Website · @DallasCowboys · @DallasCowboys · @dallascowboys · @DallasCowboys · Info · Roster · Stats · Advertising · news · Apr 21, 2025 · news · Apr 18, 2025 · news · Apr 18, 2025 ·", + "profile": { + "name": "Nfl", + "url": "https://www.nfl.com/teams/dallas-cowboys/", + "long_name": "nfl.com", + "img": "https://imgs.search.brave.com/L4B2SCyb0Ao1-76nGZVpWlnyS8TkQBAEFnf4Lpb_KRY/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvZmQxNWFiOGQ3/Mjc0MWRjYjI0OTQx/N2E5NTNhODVkMGQ2/OTI3NzcyODQ4MzU2/Nzg3YTJmMjJiOGMz/OTM1NjVmYy93d3cu/bmZsLmNvbS8" + }, + "language": "en", + "family_friendly": true, + "type": "search_result", + "subtype": "generic", + "is_live": false, + "meta_url": { + "scheme": "https", + "netloc": "nfl.com", + "hostname": "www.nfl.com", + "favicon": "https://imgs.search.brave.com/L4B2SCyb0Ao1-76nGZVpWlnyS8TkQBAEFnf4Lpb_KRY/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvZmQxNWFiOGQ3/Mjc0MWRjYjI0OTQx/N2E5NTNhODVkMGQ2/OTI3NzcyODQ4MzU2/Nzg3YTJmMjJiOGMz/OTM1NjVmYy93d3cu/bmZsLmNvbS8", + "path": "› teams › dallas-cowboys" + }, + "thumbnail": { + "src": "https://imgs.search.brave.com/0EETrH0-svPjjtTCtvV7kIQ5DuPcD03NtVFgNG8A2KM/rs:fit:200:200:1:0/g:ce/aHR0cHM6Ly9zdGF0/aWMud3d3Lm5mbC5j/b20vdF9oZWFkc2hv/dF9kZXNrdG9wL2xl/YWd1ZS9hcGkvY2x1/YnMvbG9nb3MvREFM", + "original": "https://static.www.nfl.com/t_headshot_desktop/league/api/clubs/logos/DAL", + "logo": true + } + }, + { + "title": "Dallas Cowboys News, Videos, Schedule, Roster, Stats - Yahoo Sports", + "url": "https://sports.yahoo.com/nfl/teams/dallas/", + "is_source_local": false, + "is_source_both": false, + "description": "The team has 10 selections to get right in hopes of turning a 7-10 season in 2024 into an afterthought, and helping the Cowboys return to the playoffs. The good news is that Dallas has been one of the better drafting teams in the league, and they routinely find the right players to help win games.", + "profile": { + "name": "Yahoo!", + "url": "https://sports.yahoo.com/nfl/teams/dallas/", + "long_name": "sports.yahoo.com", + "img": "https://imgs.search.brave.com/pqn4FSmuomyVnDi_JumaeN-Milit-_D15P8bquK1CAc/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvOTUxYWU5NWYz/ZWIxOGI3M2Q5MzFh/MjlmZDczOWEyMzY5/M2FhZTZiOGIzOTQ0/YzlkMGI3YTI2MmM2/ZmJmMWE2Zi9zcG9y/dHMueWFob28uY29t/Lw" + }, + "language": "en", + "family_friendly": true, + "type": "search_result", + "subtype": "generic", + "is_live": false, + "meta_url": { + "scheme": "https", + "netloc": "sports.yahoo.com", + "hostname": "sports.yahoo.com", + "favicon": "https://imgs.search.brave.com/pqn4FSmuomyVnDi_JumaeN-Milit-_D15P8bquK1CAc/rs:fit:32:32:1:0/g:ce/aHR0cDovL2Zhdmlj/b25zLnNlYXJjaC5i/cmF2ZS5jb20vaWNv/bnMvOTUxYWU5NWYz/ZWIxOGI3M2Q5MzFh/MjlmZDczOWEyMzY5/M2FhZTZiOGIzOTQ0/YzlkMGI3YTI2MmM2/ZmJmMWE2Zi9zcG9y/dHMueWFob28uY29t/Lw", + "path": "› nfl › teams › dallas" + }, + "thumbnail": { + "src": "https://imgs.search.brave.com/okMqy3l4NCCt72TSaYUZSdJcmgHc9Q1i63ttIe-rIQ0/rs:fit:200:200:1:0/g:ce/aHR0cHM6Ly9zLnlp/bWcuY29tL2l0L2Fw/aS9yZXMvMS4yL3ZX/bjVTVDlRQl95NmFi/dWl2cWdpQkEtLX5B/L1lYQndhV1E5ZVc1/bGQzTTdkejB4TWpB/d08yZzlOak13TzNF/OU1UQXcvaHR0cHM6/Ly9zLnlpbWcuY29t/L2N2L2FwaXYyL2Rl/ZmF1bHQvbmZsLzIw/MTkwNzI0LzUwMHg1/MDAvMjAxOV9EQUxf/d2JnLnBuZw", + "original": "https://s.yimg.com/it/api/res/1.2/vWn5ST9QB_y6abuivqgiBA--~A/YXBwaWQ9eW5ld3M7dz0xMjAwO2g9NjMwO3E9MTAw/https://s.yimg.com/cv/apiv2/default/nfl/20190724/500x500/2019_DAL_wbg.png", + "logo": false + } + } + ], + "family_friendly": true + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/openmeteo-current.json b/src/agent/tests/Toolbox/Tool/fixtures/openmeteo-current.json new file mode 100644 index 000000000..16d6cb266 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/openmeteo-current.json @@ -0,0 +1,23 @@ +{ + "latitude": 52.52, + "longitude": 13.419998, + "generationtime_ms": 0.06508827209472656, + "utc_offset_seconds": 0, + "timezone": "GMT", + "timezone_abbreviation": "GMT", + "elevation": 40.0, + "current_units": { + "time": "iso8601", + "interval": "seconds", + "weather_code": "wmo code", + "temperature_2m": "°C", + "wind_speed_10m": "km/h" + }, + "current": { + "time": "2024-12-21T01:15", + "interval": 900, + "weather_code": 3, + "temperature_2m": 2.6, + "wind_speed_10m": 10.7 + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/openmeteo-forecast.json b/src/agent/tests/Toolbox/Tool/fixtures/openmeteo-forecast.json new file mode 100644 index 000000000..beb4e1413 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/openmeteo-forecast.json @@ -0,0 +1,37 @@ +{ + "latitude": 52.52, + "longitude": 13.419998, + "generationtime_ms": 0.0629425048828125, + "utc_offset_seconds": 0, + "timezone": "GMT", + "timezone_abbreviation": "GMT", + "elevation": 38.0, + "daily_units": { + "time": "iso8601", + "weather_code": "wmo code", + "temperature_2m_max": "°C", + "temperature_2m_min": "°C" + }, + "daily": { + "time": [ + "2024-12-21", + "2024-12-22", + "2024-12-23" + ], + "weather_code": [ + 61, + 80, + 85 + ], + "temperature_2m_max": [ + 6.0, + 6.4, + 4.1 + ], + "temperature_2m_min": [ + 2.0, + 1.3, + 1.5 + ] + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article-missing.json b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article-missing.json new file mode 100644 index 000000000..2bea603dd --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article-missing.json @@ -0,0 +1,16 @@ +{ + "batchcomplete": "", + "query": { + "pages": { + "-1": { + "ns": 0, + "title": "Blah blah blah", + "missing": "", + "contentmodel": "wikitext", + "pagelanguage": "en", + "pagelanguagehtmlcode": "en", + "pagelanguagedir": "ltr" + } + } + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article-redirect.json b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article-redirect.json new file mode 100644 index 000000000..01175cd27 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article-redirect.json @@ -0,0 +1,32 @@ +{ + "batchcomplete": "", + "query": { + "redirects": [ + { + "from": "United Nations secretary-general", + "to": "Secretary-General of the United Nations" + } + ], + "pages": { + "162415": { + "pageid": 162415, + "ns": 0, + "title": "Secretary-General of the United Nations", + "extract": "The secretary-general of the United Nations (UNSG or UNSECGEN) is the chief administrative officer of the United Nations and head of the United Nations Secretariat, one of the six principal organs of the United Nations. And so on.", + "contentmodel": "wikitext", + "pagelanguage": "en", + "pagelanguagehtmlcode": "en", + "pagelanguagedir": "ltr", + "touched": "2024-12-07T14:43:16Z", + "lastrevid": 1259468323, + "length": 35508, + "thumbnail": { + "source": "https:\/\/upload.wikimedia.org\/wikipedia\/commons\/thumb\/5\/52\/Emblem_of_the_United_Nations.svg\/50px-Emblem_of_the_United_Nations.svg.png", + "width": 50, + "height": 43 + }, + "pageimage": "Emblem_of_the_United_Nations.svg" + } + } + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article.json b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article.json new file mode 100644 index 000000000..6275e92b9 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-article.json @@ -0,0 +1,26 @@ +{ + "batchcomplete": "", + "query": { + "pages": { + "162415": { + "pageid": 162415, + "ns": 0, + "title": "Secretary-General of the United Nations", + "extract": "The secretary-general of the United Nations (UNSG or UNSECGEN) is the chief administrative officer of the United Nations and head of the United Nations Secretariat, one of the six principal organs of the United Nations. And so on.", + "contentmodel": "wikitext", + "pagelanguage": "en", + "pagelanguagehtmlcode": "en", + "pagelanguagedir": "ltr", + "touched": "2024-12-07T14:43:16Z", + "lastrevid": 1259468323, + "length": 35508, + "thumbnail": { + "source": "https:\/\/upload.wikimedia.org\/wikipedia\/commons\/thumb\/5\/52\/Emblem_of_the_United_Nations.svg\/50px-Emblem_of_the_United_Nations.svg.png", + "width": 50, + "height": 43 + }, + "pageimage": "Emblem_of_the_United_Nations.svg" + } + } + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-search-empty.json b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-search-empty.json new file mode 100644 index 000000000..3a547ec2b --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-search-empty.json @@ -0,0 +1,11 @@ +{ + "batchcomplete": "", + "query": { + "searchinfo": { + "totalhits": 0 + }, + "search": [ + + ] + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-search-result.json b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-search-result.json new file mode 100644 index 000000000..20c725466 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/wikipedia-search-result.json @@ -0,0 +1,104 @@ +{ + "batchcomplete": "", + "continue": { + "sroffset": 10, + "continue": "-||" + }, + "query": { + "searchinfo": { + "totalhits": 27227 + }, + "search": [ + { + "ns": 0, + "title": "Under-Secretary-General of the United Nations", + "pageid": 3223434, + "size": 15971, + "wordcount": 1569, + "snippet": "An under-secretary-general of the United Nations (USG) is a senior official within the United Nations System, normally appointed by the General Assembly", + "timestamp": "2024-11-28T08:11:08Z" + }, + { + "ns": 0, + "title": "United Nations secretary-general selection", + "pageid": 52735558, + "size": 36343, + "wordcount": 4335, + "snippet": "United Nations secretary-general selection is the process of selecting the next secretary-general of the United Nations. To be selected as secretary-general", + "timestamp": "2024-08-23T07:56:28Z" + }, + { + "ns": 0, + "title": "List of current permanent representatives to the United Nations", + "pageid": 4409476, + "size": 47048, + "wordcount": 1524, + "snippet": "is a list of the current permanent representatives to the United Nations at United Nations Headquarters, New York City. The list includes the country that", + "timestamp": "2024-11-11T03:04:58Z" + }, + { + "ns": 0, + "title": "United Nations", + "pageid": 31769, + "size": 173187, + "wordcount": 15417, + "snippet": "The United Nations (UN) is a diplomatic and political international organization with the intended purpose of maintaining international peace and security", + "timestamp": "2024-11-30T10:50:21Z" + }, + { + "ns": 0, + "title": "United Nations Secretariat", + "pageid": 162410, + "size": 23340, + "wordcount": 2389, + "snippet": "The United Nations Secretariat is one of the six principal organs of the United Nations (UN), The secretariat is the UN's executive arm. The secretariat", + "timestamp": "2024-10-07T15:57:54Z" + }, + { + "ns": 0, + "title": "Flag of the United Nations", + "pageid": 565612, + "size": 18615, + "wordcount": 1219, + "snippet": "The flag of the United Nations is a sky blue banner containing the United Nations' emblem in the centre. The emblem on the flag is coloured white; it is", + "timestamp": "2024-09-26T02:16:57Z" + }, + { + "ns": 0, + "title": "List of current members of the United States House of Representatives", + "pageid": 12498224, + "size": 262590, + "wordcount": 1704, + "snippet": "in the United States House of Representatives List of current United States senators List of members of the United States Congress by longevity of service", + "timestamp": "2024-12-06T14:58:13Z" + }, + { + "ns": 0, + "title": "Member states of the United Nations", + "pageid": 31969, + "size": 107893, + "wordcount": 8436, + "snippet": "The member states of the United Nations comprise 193 sovereign states. The United Nations (UN) is the world's largest intergovernmental organization.", + "timestamp": "2024-12-08T15:19:12Z" + }, + { + "ns": 0, + "title": "Official languages of the United Nations", + "pageid": 25948712, + "size": 57104, + "wordcount": 4976, + "snippet": "The official languages of the United Nations are the six languages used in United Nations (UN) meetings and in which the UN writes all its official documents", + "timestamp": "2024-11-11T13:41:37Z" + }, + { + "ns": 0, + "title": "United States Secretary of State", + "pageid": 32293, + "size": 18112, + "wordcount": 1513, + "snippet": "The United States secretary of state (SecState) is a member of the executive branch of the federal government and the head of the Department of State", + "timestamp": "2024-12-01T17:58:13Z" + } + ] + } +} diff --git a/src/agent/tests/Toolbox/ToolResultConverterTest.php b/src/agent/tests/Toolbox/ToolResultConverterTest.php new file mode 100644 index 000000000..4a4c41596 --- /dev/null +++ b/src/agent/tests/Toolbox/ToolResultConverterTest.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\ToolResultConverter; + +#[CoversClass(ToolResultConverter::class)] +final class ToolResultConverterTest extends TestCase +{ + #[Test] + #[DataProvider('provideResults')] + public function testConvert(mixed $result, ?string $expected): void + { + $converter = new ToolResultConverter(); + + self::assertSame($expected, $converter->convert($result)); + } + + public static function provideResults(): \Generator + { + yield 'null' => [null, null]; + + yield 'integer' => [42, '42']; + + yield 'float' => [42.42, '42.42']; + + yield 'array' => [['key' => 'value'], '{"key":"value"}']; + + yield 'string' => ['plain string', 'plain string']; + + yield 'datetime' => [new \DateTimeImmutable('2021-07-31 12:34:56'), '2021-07-31T12:34:56+00:00']; + + yield 'stringable' => [ + new class implements \Stringable { + public function __toString(): string + { + return 'stringable'; + } + }, + 'stringable', + ]; + + yield 'json_serializable' => [ + new class implements \JsonSerializable { + public function jsonSerialize(): array + { + return ['key' => 'value']; + } + }, + '{"key":"value"}', + ]; + } +} diff --git a/src/agent/tests/Toolbox/ToolboxTest.php b/src/agent/tests/Toolbox/ToolboxTest.php new file mode 100644 index 000000000..64efcc456 --- /dev/null +++ b/src/agent/tests/Toolbox/ToolboxTest.php @@ -0,0 +1,265 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Tests\Toolbox; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolException; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolMisconfigured; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoAttribute1; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolOptionalParam; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\Exception\ToolConfigurationException; +use Symfony\AI\Agent\Toolbox\Exception\ToolExecutionException; +use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Agent\Toolbox\ToolFactory\ChainFactory; +use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory; +use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; +use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(Toolbox::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(AsTool::class)] +#[UsesClass(Tool::class)] +#[UsesClass(ExecutionReference::class)] +#[UsesClass(ReflectionToolFactory::class)] +#[UsesClass(MemoryToolFactory::class)] +#[UsesClass(ChainFactory::class)] +#[UsesClass(Factory::class)] +#[UsesClass(DescriptionParser::class)] +#[UsesClass(ToolConfigurationException::class)] +#[UsesClass(ToolNotFoundException::class)] +#[UsesClass(ToolExecutionException::class)] +final class ToolboxTest extends TestCase +{ + private Toolbox $toolbox; + + protected function setUp(): void + { + $this->toolbox = new Toolbox(new ReflectionToolFactory(), [ + new ToolRequiredParams(), + new ToolOptionalParam(), + new ToolNoParams(), + new ToolException(), + ]); + } + + #[Test] + public function getTools(): void + { + $actual = $this->toolbox->getTools(); + + $toolRequiredParams = new Tool( + new ExecutionReference(ToolRequiredParams::class, 'bar'), + 'tool_required_params', + 'A tool with required parameters', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ); + + $toolOptionalParam = new Tool( + new ExecutionReference(ToolOptionalParam::class, 'bar'), + 'tool_optional_param', + 'A tool with one optional parameter', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text'], + 'additionalProperties' => false, + ], + ); + + $toolNoParams = new Tool( + new ExecutionReference(ToolNoParams::class), + 'tool_no_params', + 'A tool without parameters', + ); + + $toolException = new Tool( + new ExecutionReference(ToolException::class, 'bar'), + 'tool_exception', + 'This tool is broken', + ); + + $expected = [ + $toolRequiredParams, + $toolOptionalParam, + $toolNoParams, + $toolException, + ]; + + self::assertEquals($expected, $actual); + } + + #[Test] + public function executeWithUnknownTool(): void + { + self::expectException(ToolNotFoundException::class); + self::expectExceptionMessage('Tool not found for call: foo_bar_baz'); + + $this->toolbox->execute(new ToolCall('call_1234', 'foo_bar_baz')); + } + + #[Test] + public function executeWithMisconfiguredTool(): void + { + self::expectException(ToolConfigurationException::class); + self::expectExceptionMessage('Method "foo" not found in tool "Symfony\AI\Agent\Tests\Fixture\Tool\ToolMisconfigured".'); + + $toolbox = new Toolbox(new ReflectionToolFactory(), [new ToolMisconfigured()]); + + $toolbox->execute(new ToolCall('call_1234', 'tool_misconfigured')); + } + + #[Test] + public function executeWithException(): void + { + self::expectException(ToolExecutionException::class); + self::expectExceptionMessage('Execution of tool "tool_exception" failed with error: Tool error.'); + + $this->toolbox->execute(new ToolCall('call_1234', 'tool_exception')); + } + + #[Test] + #[DataProvider('executeProvider')] + public function execute(string $expected, string $toolName, array $toolPayload = []): void + { + self::assertSame( + $expected, + $this->toolbox->execute(new ToolCall('call_1234', $toolName, $toolPayload)), + ); + } + + /** + * @return iterable + */ + public static function executeProvider(): iterable + { + yield 'tool_required_params' => [ + 'Hello says "3".', + 'tool_required_params', + ['text' => 'Hello', 'number' => 3], + ]; + } + + #[Test] + public function toolboxMapWithMemoryFactory(): void + { + $memoryFactory = (new MemoryToolFactory()) + ->addTool(ToolNoAttribute1::class, 'happy_birthday', 'Generates birthday message'); + + $toolbox = new Toolbox($memoryFactory, [new ToolNoAttribute1()]); + $expected = [ + new Tool( + new ExecutionReference(ToolNoAttribute1::class, '__invoke'), + 'happy_birthday', + 'Generates birthday message', + [ + 'type' => 'object', + 'properties' => [ + 'name' => [ + 'type' => 'string', + 'description' => 'the name of the person', + ], + 'years' => [ + 'type' => 'integer', + 'description' => 'the age of the person', + ], + ], + 'required' => ['name', 'years'], + 'additionalProperties' => false, + ], + ), + ]; + + self::assertEquals($expected, $toolbox->getTools()); + } + + #[Test] + public function toolboxExecutionWithMemoryFactory(): void + { + $memoryFactory = (new MemoryToolFactory()) + ->addTool(ToolNoAttribute1::class, 'happy_birthday', 'Generates birthday message'); + + $toolbox = new Toolbox($memoryFactory, [new ToolNoAttribute1()]); + $response = $toolbox->execute(new ToolCall('call_1234', 'happy_birthday', ['name' => 'John', 'years' => 30])); + + self::assertSame('Happy Birthday, John! You are 30 years old.', $response); + } + + #[Test] + public function toolboxMapWithOverrideViaChain(): void + { + $factory1 = (new MemoryToolFactory()) + ->addTool(ToolOptionalParam::class, 'optional_param', 'Tool with optional param', 'bar'); + $factory2 = new ReflectionToolFactory(); + + $toolbox = new Toolbox(new ChainFactory([$factory1, $factory2]), [new ToolOptionalParam()]); + + $expected = [ + new Tool( + new ExecutionReference(ToolOptionalParam::class, 'bar'), + 'optional_param', + 'Tool with optional param', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text'], + 'additionalProperties' => false, + ], + ), + ]; + + self::assertEquals($expected, $toolbox->getTools()); + } +} diff --git a/src/ai-bundle/.gitattributes b/src/ai-bundle/.gitattributes new file mode 100644 index 000000000..ec8c01802 --- /dev/null +++ b/src/ai-bundle/.gitattributes @@ -0,0 +1,6 @@ +/.github export-ignore +/tests export-ignore +.gitattributes export-ignore +.gitignore export-ignore +phpstan.dist.neon export-ignore +phpunit.xml.dist export-ignore diff --git a/src/ai-bundle/.github/PULL_REQUEST_TEMPLATE.md b/src/ai-bundle/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..fcb87228a --- /dev/null +++ b/src/ai-bundle/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ +Please do not submit any Pull Requests here. They will be closed. +--- + +Please submit your PR here instead: +https://github.com/symfony/ai + +This repository is what we call a "subtree split": a read-only subset of that main repository. +We're looking forward to your PR there! diff --git a/src/ai-bundle/.github/workflows/close-pull-request.yml b/src/ai-bundle/.github/workflows/close-pull-request.yml new file mode 100644 index 000000000..207153fd5 --- /dev/null +++ b/src/ai-bundle/.github/workflows/close-pull-request.yml @@ -0,0 +1,20 @@ +name: Close Pull Request + +on: + pull_request_target: + types: [opened] + +jobs: + run: + runs-on: ubuntu-latest + steps: + - uses: superbrothers/close-pull-request@v3 + with: + comment: | + Thanks for your Pull Request! We love contributions. + + However, you should instead open your PR on the main repository: + https://github.com/symfony/ai + + This repository is what we call a "subtree split": a read-only subset of that main repository. + We're looking forward to your PR there! diff --git a/src/ai-bundle/.gitignore b/src/ai-bundle/.gitignore new file mode 100644 index 000000000..e90f5de2e --- /dev/null +++ b/src/ai-bundle/.gitignore @@ -0,0 +1,5 @@ +vendor +composer.lock +.php-cs-fixer.cache +.phpunit.cache +coverage diff --git a/src/ai-bundle/LICENSE b/src/ai-bundle/LICENSE new file mode 100644 index 000000000..bc38d714e --- /dev/null +++ b/src/ai-bundle/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2025-present Fabien Potencier + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/src/ai-bundle/README.md b/src/ai-bundle/README.md new file mode 100644 index 000000000..b20243c46 --- /dev/null +++ b/src/ai-bundle/README.md @@ -0,0 +1,176 @@ +# Symfony AI Bundle + +Symfony integration bundle for [symfony/ai](https://github.com/symfony/ai) components. + +## Installation + +```bash +composer require symfony/ai-bundle +``` + +## Configuration + +### Simple Example with OpenAI + +```yaml +# config/packages/ai.yaml +ai: + platform: + openai: + api_key: '%env(OPENAI_API_KEY)%' + agent: + default: + model: + name: 'GPT' +``` + +### Advanced Example with Anthropic, Azure, Google and multiple agents +```yaml +# config/packages/ai.yaml +ai: + platform: + anthropic: + api_key: '%env(ANTHROPIC_API_KEY)%' + azure: + # multiple deployments possible + gpt_deployment: + base_url: '%env(AZURE_OPENAI_BASEURL)%' + deployment: '%env(AZURE_OPENAI_GPT)%' + api_key: '%env(AZURE_OPENAI_KEY)%' + api_version: '%env(AZURE_GPT_VERSION)%' + google: + api_key: '%env(GOOGLE_API_KEY)%' + agent: + rag: + platform: 'symfony_ai.platform.azure.gpt_deployment' + structured_output: false # Disables support for "output_structure" option, default is true + model: + name: 'GPT' + version: 'gpt-4o-mini' + system_prompt: 'You are a helpful assistant that can answer questions.' # The default system prompt of the agent + include_tools: true # Include tool definitions at the end of the system prompt + tools: + # Referencing a service with #[AsTool] attribute + - 'Symfony\AI\Agent\Toolbox\Tool\SimilaritySearch' + + # Referencing a service without #[AsTool] attribute + - service: 'App\Agent\Tool\CompanyName' + name: 'company_name' + description: 'Provides the name of your company' + method: 'foo' # Optional with default value '__invoke' + + # Referencing an agent => agent uses agent 🤯 + - service: 'symfony_ai.agent.research' + name: 'wikipedia_research' + description: 'Can research on Wikipedia' + is_agent: true + research: + platform: 'symfony_ai.platform.anthropic' + model: + name: 'Claude' + tools: # If undefined, all tools are injected into the agent, use "tools: false" to disable tools. + - 'Symfony\AI\Agent\Toolbox\Tool\Wikipedia' + fault_tolerant_toolbox: false # Disables fault tolerant toolbox, default is true + store: + # also azure_search, mongodb and pinecone are supported as store type + chroma_db: + # multiple collections possible per type + default: + collection: 'my_collection' + embedder: + default: + # platform: 'symfony_ai.platform.anthropic' + # store: 'symfony_ai.store.chroma_db.default' + model: + name: 'Embeddings' + version: 'text-embedding-ada-002' +``` + +## Usage + +### Agent Service + +Use the `Agent` service to leverage GPT: +```php +use Symfony\AI\Agent\AgentInterface; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +final readonly class MyService +{ + public function __construct( + private AgentInterface $agent, + ) { + } + + public function submit(string $message): string + { + $messages = new MessageBag( + Message::forSystem('Speak like a pirate.'), + Message::ofUser($message), + ); + + return $this->agent->call($messages); + } +} +``` + +### Register Tools + +To use existing tools, you can register them as a service: +```yaml +services: + _defaults: + autowire: true + autoconfigure: true + + Symfony\AI\Agent\Toolbox\Tool\Clock: ~ + Symfony\AI\Agent\Toolbox\Tool\OpenMeteo: ~ + Symfony\AI\Agent\Toolbox\Tool\SerpApi: + $apiKey: '%env(SERP_API_KEY)%' + Symfony\AI\Agent\Toolbox\Tool\SimilaritySearch: ~ + Symfony\AI\Agent\Toolbox\Tool\Tavily: + $apiKey: '%env(TAVILY_API_KEY)%' + Symfony\AI\Agent\Toolbox\Tool\Wikipedia: ~ + Symfony\AI\Agent\Toolbox\Tool\YouTubeTranscriber: ~ +``` + +Custom tools can be registered by using the `#[AsTool]` attribute: + +```php +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; + +#[AsTool('company_name', 'Provides the name of your company')] +final class CompanyName +{ + public function __invoke(): string + { + return 'ACME Corp.' + } +} +``` + +The agent configuration by default will inject all known tools into the agent. + +To disable this behavior, set the `tools` option to `false`: +```yaml +ai: + agent: + my_agent: + tools: false +``` + +To inject only specific tools, list them in the configuration: +```yaml +ai: + agent: + my_agent: + tools: + - 'Symfony\AI\Agent\Toolbox\Tool\SimilaritySearch' +``` + +### Profiler + +The profiler panel provides insights into the agent's execution: + +![Profiler](./profiler.png) diff --git a/src/ai-bundle/composer.json b/src/ai-bundle/composer.json new file mode 100644 index 000000000..38c56ac5c --- /dev/null +++ b/src/ai-bundle/composer.json @@ -0,0 +1,44 @@ +{ + "name": "symfony/ai-bundle", + "type": "symfony-bundle", + "description": "Symfony integration bundle for Symfony's AI components", + "license": "MIT", + "authors": [ + { + "name": "Christopher Hertel", + "email": "mail@christopher-hertel.de" + }, + { + "name": "Oskar Stark", + "email": "oskarstark@googlemail.com" + } + ], + "require": { + "php": ">=8.2", + "symfony/ai-agent": "dev-main", + "symfony/ai-platform": "dev-main", + "symfony/ai-store": "dev-main", + "symfony/config": "^6.4 || ^7.0", + "symfony/dependency-injection": "^6.4 || ^7.0", + "symfony/framework-bundle": "^6.4 || ^7.0", + "symfony/string": "^6.4 || ^7.0" + }, + "require-dev": { + "phpstan/phpstan": "^2.1", + "phpunit/phpunit": "^11.5", + "rector/rector": "^2.0" + }, + "config": { + "sort-packages": true + }, + "autoload": { + "psr-4": { + "Symfony\\AI\\AIBundle\\": "src/" + } + }, + "autoload-dev": { + "psr-4": { + "Symfony\\AI\\AIBundle\\Tests\\": "tests/" + } + } +} diff --git a/src/ai-bundle/phpstan.dist.neon b/src/ai-bundle/phpstan.dist.neon new file mode 100644 index 000000000..e7fd00638 --- /dev/null +++ b/src/ai-bundle/phpstan.dist.neon @@ -0,0 +1,8 @@ +parameters: + level: 6 + paths: + - src/ + - tests/ + excludePaths: + analyse: + - src/DependencyInjection/Configuration.php diff --git a/src/ai-bundle/phpunit.xml.dist b/src/ai-bundle/phpunit.xml.dist new file mode 100644 index 000000000..4e9e3a684 --- /dev/null +++ b/src/ai-bundle/phpunit.xml.dist @@ -0,0 +1,24 @@ + + + + + tests + + + + + + src + + + diff --git a/src/ai-bundle/profiler.png b/src/ai-bundle/profiler.png new file mode 100644 index 000000000..6aa2b5fe1 Binary files /dev/null and b/src/ai-bundle/profiler.png differ diff --git a/src/ai-bundle/src/AIBundle.php b/src/ai-bundle/src/AIBundle.php new file mode 100644 index 000000000..9270fc1a4 --- /dev/null +++ b/src/ai-bundle/src/AIBundle.php @@ -0,0 +1,18 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle; + +use Symfony\Component\HttpKernel\Bundle\Bundle; + +final class AIBundle extends Bundle +{ +} diff --git a/src/ai-bundle/src/DependencyInjection/AIExtension.php b/src/ai-bundle/src/DependencyInjection/AIExtension.php new file mode 100644 index 000000000..b4281c84c --- /dev/null +++ b/src/ai-bundle/src/DependencyInjection/AIExtension.php @@ -0,0 +1,456 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle\DependencyInjection; + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\AgentInterface; +use Symfony\AI\Agent\InputProcessor\SystemPromptInputProcessor; +use Symfony\AI\Agent\InputProcessorInterface; +use Symfony\AI\Agent\OutputProcessorInterface; +use Symfony\AI\Agent\StructuredOutput\AgentProcessor as StructureOutputProcessor; +use Symfony\AI\Agent\Toolbox\AgentProcessor as ToolProcessor; +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\AI\Agent\Toolbox\FaultTolerantToolbox; +use Symfony\AI\Agent\Toolbox\Tool\Agent as AgentTool; +use Symfony\AI\Agent\Toolbox\ToolFactory\ChainFactory; +use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory; +use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; +use Symfony\AI\AIBundle\Profiler\DataCollector; +use Symfony\AI\AIBundle\Profiler\TraceablePlatform; +use Symfony\AI\AIBundle\Profiler\TraceableToolbox; +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory; +use Symfony\AI\Platform\Bridge\Azure\OpenAI\PlatformFactory as AzureOpenAIPlatformFactory; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Bridge\Google\PlatformFactory as GooglePlatformFactory; +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Bridge\OpenAI\PlatformFactory as OpenAIPlatformFactory; +use Symfony\AI\Platform\Bridge\Voyage\Voyage; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Platform; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\AI\Store\Bridge\Azure\SearchStore as AzureSearchStore; +use Symfony\AI\Store\Bridge\ChromaDB\Store as ChromaDBStore; +use Symfony\AI\Store\Bridge\MongoDB\Store as MongoDBStore; +use Symfony\AI\Store\Bridge\Pinecone\Store as PineconeStore; +use Symfony\AI\Store\Embedder; +use Symfony\AI\Store\StoreInterface; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Config\FileLocator; +use Symfony\Component\DependencyInjection\ChildDefinition; +use Symfony\Component\DependencyInjection\ContainerBuilder; +use Symfony\Component\DependencyInjection\Definition; +use Symfony\Component\DependencyInjection\Extension\Extension; +use Symfony\Component\DependencyInjection\Loader\PhpFileLoader; +use Symfony\Component\DependencyInjection\Reference; + +use function Symfony\Component\String\u; + +final class AIExtension extends Extension +{ + public function load(array $configs, ContainerBuilder $container): void + { + $loader = new PhpFileLoader($container, new FileLocator(\dirname(__DIR__).'/Resources/config')); + $loader->load('services.php'); + + $configuration = new Configuration(); + $config = $this->processConfiguration($configuration, $configs); + foreach ($config['platform'] ?? [] as $type => $platform) { + $this->processPlatformConfig($type, $platform, $container); + } + $platforms = array_keys($container->findTaggedServiceIds('symfony_ai.platform')); + if (1 === \count($platforms)) { + $container->setAlias(PlatformInterface::class, reset($platforms)); + } + if ($container->getParameter('kernel.debug')) { + foreach ($platforms as $platform) { + $traceablePlatformDefinition = (new Definition(TraceablePlatform::class)) + ->setDecoratedService($platform) + ->setAutowired(true) + ->addTag('symfony_ai.traceable_platform'); + $suffix = u($platform)->afterLast('.')->toString(); + $container->setDefinition('symfony_ai.traceable_platform.'.$suffix, $traceablePlatformDefinition); + } + } + + foreach ($config['agent'] as $agentName => $agent) { + $this->processAgentConfig($agentName, $agent, $container); + } + if (1 === \count($config['agent']) && isset($agentName)) { + $container->setAlias(AgentInterface::class, 'symfony_ai.agent.'.$agentName); + } + + foreach ($config['store'] ?? [] as $type => $store) { + $this->processStoreConfig($type, $store, $container); + } + $stores = array_keys($container->findTaggedServiceIds('symfony_ai.store')); + if (1 === \count($stores)) { + $container->setAlias(VectorStoreInterface::class, reset($stores)); + $container->setAlias(StoreInterface::class, reset($stores)); + } + + foreach ($config['embedder'] as $embedderName => $embedder) { + $this->processEmbedderConfig($embedderName, $embedder, $container); + } + if (1 === \count($config['embedder']) && isset($embedderName)) { + $container->setAlias(Embedder::class, 'symfony_ai.embedder.'.$embedderName); + } + + $container->registerAttributeForAutoconfiguration(AsTool::class, static function (ChildDefinition $definition, AsTool $attribute): void { + $definition->addTag('symfony_ai.tool', [ + 'name' => $attribute->name, + 'description' => $attribute->description, + 'method' => $attribute->method, + ]); + }); + + $container->registerForAutoconfiguration(InputProcessorInterface::class) + ->addTag('symfony_ai.agent.input_processor'); + $container->registerForAutoconfiguration(OutputProcessorInterface::class) + ->addTag('symfony_ai.agent.output_processor'); + $container->registerForAutoconfiguration(ModelClientInterface::class) + ->addTag('symfony_ai.platform.model_client'); + $container->registerForAutoconfiguration(ResponseConverterInterface::class) + ->addTag('symfony_ai.platform.response_converter'); + + if (false === $container->getParameter('kernel.debug')) { + $container->removeDefinition(DataCollector::class); + $container->removeDefinition(TraceableToolbox::class); + } + } + + /** + * @param array $platform + */ + private function processPlatformConfig(string $type, array $platform, ContainerBuilder $container): void + { + if ('anthropic' === $type) { + $platformId = 'symfony_ai.platform.anthropic'; + $definition = (new Definition(Platform::class)) + ->setFactory(AnthropicPlatformFactory::class.'::create') + ->setAutowired(true) + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments([ + '$apiKey' => $platform['api_key'], + ]) + ->addTag('symfony_ai.platform'); + + if (isset($platform['version'])) { + $definition->replaceArgument('$version', $platform['version']); + } + + $container->setDefinition($platformId, $definition); + + return; + } + + if ('azure' === $type) { + foreach ($platform as $name => $config) { + $platformId = 'symfony_ai.platform.azure.'.$name; + $definition = (new Definition(Platform::class)) + ->setFactory(AzureOpenAIPlatformFactory::class.'::create') + ->setAutowired(true) + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments([ + '$baseUrl' => $config['base_url'], + '$deployment' => $config['deployment'], + '$apiVersion' => $config['api_version'], + '$apiKey' => $config['api_key'], + ]) + ->addTag('symfony_ai.platform'); + + $container->setDefinition($platformId, $definition); + } + + return; + } + + if ('google' === $type) { + $platformId = 'symfony_ai.platform.google'; + $definition = (new Definition(Platform::class)) + ->setFactory(GooglePlatformFactory::class.'::create') + ->setAutowired(true) + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments(['$apiKey' => $platform['api_key']]) + ->addTag('symfony_ai.platform'); + + $container->setDefinition($platformId, $definition); + + return; + } + + if ('openai' === $type) { + $platformId = 'symfony_ai.platform.openai'; + $definition = (new Definition(Platform::class)) + ->setFactory(OpenAIPlatformFactory::class.'::create') + ->setAutowired(true) + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments(['$apiKey' => $platform['api_key']]) + ->addTag('symfony_ai.platform'); + + $container->setDefinition($platformId, $definition); + + return; + } + + throw new \InvalidArgumentException(\sprintf('Platform "%s" is not supported for configuration via bundle at this point.', $type)); + } + + /** + * @param array $config + */ + private function processAgentConfig(string $name, array $config, ContainerBuilder $container): void + { + // MODEL + ['name' => $modelName, 'version' => $version, 'options' => $options] = $config['model']; + + $modelClass = match (strtolower((string) $modelName)) { + 'gpt' => GPT::class, + 'claude' => Claude::class, + 'llama' => Llama::class, + 'gemini' => Gemini::class, + default => throw new \InvalidArgumentException(\sprintf('Model "%s" is not supported.', $modelName)), + }; + $modelDefinition = new Definition($modelClass); + if (null !== $version) { + $modelDefinition->setArgument('$name', $version); + } + if (0 !== \count($options)) { + $modelDefinition->setArgument('$options', $options); + } + $modelDefinition->addTag('symfony_ai.model.language_model'); + $container->setDefinition('symfony_ai.agent.'.$name.'.model', $modelDefinition); + + // AGENT + $agentDefinition = (new Definition(Agent::class)) + ->setAutowired(true) + ->setArgument('$platform', new Reference($config['platform'])) + ->setArgument('$model', new Reference('symfony_ai.agent.'.$name.'.model')); + + $inputProcessors = []; + $outputProcessors = []; + + // TOOL & PROCESSOR + if ($config['tools']['enabled']) { + // Create specific toolbox and process if tools are explicitly defined + if (0 !== \count($config['tools']['services'])) { + $memoryFactoryDefinition = new Definition(MemoryToolFactory::class); + $container->setDefinition('symfony_ai.toolbox.'.$name.'.memory_factory', $memoryFactoryDefinition); + $agentFactoryDefinition = new Definition(ChainFactory::class, [ + '$factories' => [new Reference('symfony_ai.toolbox.'.$name.'.memory_factory'), new Reference(ReflectionToolFactory::class)], + ]); + $container->setDefinition('symfony_ai.toolbox.'.$name.'.agent_factory', $agentFactoryDefinition); + + $tools = []; + foreach ($config['tools']['services'] as $tool) { + $reference = new Reference($tool['service']); + // We use the memory factory in case method, description and name are set + if (isset($tool['name'], $tool['description'])) { + if ($tool['is_agent']) { + $agentWrapperDefinition = new Definition(AgentTool::class, ['$agent' => $reference]); + $container->setDefinition('symfony_ai.toolbox.'.$name.'.agent_wrapper.'.$tool['name'], $agentWrapperDefinition); + $reference = new Reference('symfony_ai.toolbox.'.$name.'.agent_wrapper.'.$tool['name']); + } + $memoryFactoryDefinition->addMethodCall('addTool', [$reference, $tool['name'], $tool['description'], $tool['method'] ?? '__invoke']); + } + $tools[] = $reference; + } + + $toolboxDefinition = (new ChildDefinition('symfony_ai.toolbox.abstract')) + ->replaceArgument('$toolFactory', new Reference('symfony_ai.toolbox.'.$name.'.agent_factory')) + ->replaceArgument('$tools', $tools); + $container->setDefinition('symfony_ai.toolbox.'.$name, $toolboxDefinition); + + if ($config['fault_tolerant_toolbox']) { + $faultTolerantToolboxDefinition = (new Definition('symfony_ai.fault_tolerant_toolbox.'.$name)) + ->setClass(FaultTolerantToolbox::class) + ->setAutowired(true) + ->setDecoratedService('symfony_ai.toolbox.'.$name); + $container->setDefinition('symfony_ai.fault_tolerant_toolbox.'.$name, $faultTolerantToolboxDefinition); + } + + if ($container->getParameter('kernel.debug')) { + $traceableToolboxDefinition = (new Definition('symfony_ai.traceable_toolbox.'.$name)) + ->setClass(TraceableToolbox::class) + ->setAutowired(true) + ->setDecoratedService('symfony_ai.toolbox.'.$name) + ->addTag('symfony_ai.traceable_toolbox'); + $container->setDefinition('symfony_ai.traceable_toolbox.'.$name, $traceableToolboxDefinition); + } + + $toolProcessorDefinition = (new ChildDefinition('symfony_ai.tool.agent_processor.abstract')) + ->replaceArgument('$toolbox', new Reference('symfony_ai.toolbox.'.$name)); + $container->setDefinition('symfony_ai.tool.agent_processor.'.$name, $toolProcessorDefinition); + + $inputProcessors[] = new Reference('symfony_ai.tool.agent_processor.'.$name); + $outputProcessors[] = new Reference('symfony_ai.tool.agent_processor.'.$name); + } else { + $inputProcessors[] = new Reference(ToolProcessor::class); + $outputProcessors[] = new Reference(ToolProcessor::class); + } + } + + // STRUCTURED OUTPUT + if ($config['structured_output']) { + $inputProcessors[] = new Reference(StructureOutputProcessor::class); + $outputProcessors[] = new Reference(StructureOutputProcessor::class); + } + + // SYSTEM PROMPT + if (\is_string($config['system_prompt'])) { + $systemPromptInputProcessorDefinition = new Definition(SystemPromptInputProcessor::class); + $systemPromptInputProcessorDefinition + ->setAutowired(true) + ->setArguments([ + '$systemPrompt' => $config['system_prompt'], + '$toolbox' => $config['include_tools'] ? new Reference('symfony_ai.toolbox.'.$name) : null, + ]); + + $inputProcessors[] = $systemPromptInputProcessorDefinition; + } + + $agentDefinition + ->setArgument('$inputProcessors', $inputProcessors) + ->setArgument('$outputProcessors', $outputProcessors); + + $container->setDefinition('symfony_ai.agent.'.$name, $agentDefinition); + } + + /** + * @param array $stores + */ + private function processStoreConfig(string $type, array $stores, ContainerBuilder $container): void + { + if ('azure_search' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + '$endpointUrl' => $store['endpoint'], + '$apiKey' => $store['api_key'], + '$indexName' => $store['index_name'], + '$apiVersion' => $store['api_version'], + ]; + + if (\array_key_exists('vector_field', $store)) { + $arguments['$vectorFieldName'] = $store['vector_field']; + } + + $definition = new Definition(AzureSearchStore::class); + $definition + ->setAutowired(true) + ->addTag('symfony_ai.store') + ->setArguments($arguments); + + $container->setDefinition('symfony_ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('chroma_db' === $type) { + foreach ($stores as $name => $store) { + $definition = new Definition(ChromaDBStore::class); + $definition + ->setAutowired(true) + ->setArgument('$collectionName', $store['collection']) + ->addTag('symfony_ai.store'); + + $container->setDefinition('symfony_ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('mongodb' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + '$databaseName' => $store['database'], + '$collectionName' => $store['collection'], + '$indexName' => $store['index_name'], + ]; + + if (\array_key_exists('vector_field', $store)) { + $arguments['$vectorFieldName'] = $store['vector_field']; + } + + if (\array_key_exists('bulk_write', $store)) { + $arguments['$bulkWrite'] = $store['bulk_write']; + } + + $definition = new Definition(MongoDBStore::class); + $definition + ->setAutowired(true) + ->addTag('symfony_ai.store') + ->setArguments($arguments); + + $container->setDefinition('symfony_ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('pinecone' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + '$namespace' => $store['namespace'], + ]; + + if (\array_key_exists('filter', $store)) { + $arguments['$filter'] = $store['filter']; + } + + if (\array_key_exists('top_k', $store)) { + $arguments['$topK'] = $store['top_k']; + } + + $definition = new Definition(PineconeStore::class); + $definition + ->setAutowired(true) + ->addTag('symfony_ai.store') + ->setArguments($arguments); + + $container->setDefinition('symfony_ai.store.'.$type.'.'.$name, $definition); + } + } + } + + /** + * @param array $config + */ + private function processEmbedderConfig(int|string $name, array $config, ContainerBuilder $container): void + { + ['name' => $modelName, 'version' => $version, 'options' => $options] = $config['model']; + + $modelClass = match (strtolower((string) $modelName)) { + 'embeddings' => Embeddings::class, + 'voyage' => Voyage::class, + default => throw new \InvalidArgumentException(\sprintf('Model "%s" is not supported.', $modelName)), + }; + $modelDefinition = (new Definition($modelClass)); + if (null !== $version) { + $modelDefinition->setArgument('$name', $version); + } + if (0 !== \count($options)) { + $modelDefinition->setArgument('$options', $options); + } + $modelDefinition->addTag('symfony_ai.model.embeddings_model'); + $container->setDefinition('symfony_ai.embedder.'.$name.'.model', $modelDefinition); + + $definition = new Definition(Embedder::class, [ + '$model' => new Reference('symfony_ai.embedder.'.$name.'.model'), + '$platform' => new Reference($config['platform']), + '$store' => new Reference($config['store']), + ]); + + $container->setDefinition('symfony_ai.embedder.'.$name, $definition); + } +} diff --git a/src/ai-bundle/src/DependencyInjection/Configuration.php b/src/ai-bundle/src/DependencyInjection/Configuration.php new file mode 100644 index 000000000..6a01d3fe9 --- /dev/null +++ b/src/ai-bundle/src/DependencyInjection/Configuration.php @@ -0,0 +1,212 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle\DependencyInjection; + +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Store\StoreInterface; +use Symfony\Component\Config\Definition\Builder\TreeBuilder; +use Symfony\Component\Config\Definition\ConfigurationInterface; + +final class Configuration implements ConfigurationInterface +{ + public function getConfigTreeBuilder(): TreeBuilder + { + $treeBuilder = new TreeBuilder('ai'); + $rootNode = $treeBuilder->getRootNode(); + + $rootNode + ->children() + ->arrayNode('platform') + ->children() + ->arrayNode('anthropic') + ->children() + ->scalarNode('api_key')->isRequired()->end() + ->scalarNode('version')->defaultNull()->end() + ->end() + ->end() + ->arrayNode('azure') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('api_key')->isRequired()->end() + ->scalarNode('base_url')->isRequired()->end() + ->scalarNode('deployment')->isRequired()->end() + ->scalarNode('api_version')->info('The used API version')->end() + ->end() + ->end() + ->end() + ->arrayNode('google') + ->children() + ->scalarNode('api_key')->isRequired()->end() + ->end() + ->end() + ->arrayNode('openai') + ->children() + ->scalarNode('api_key')->isRequired()->end() + ->end() + ->end() + ->end() + ->end() + ->arrayNode('agent') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('platform') + ->info('Service name of platform') + ->defaultValue(PlatformInterface::class) + ->end() + ->arrayNode('model') + ->children() + ->scalarNode('name')->isRequired()->end() + ->scalarNode('version')->defaultNull()->end() + ->arrayNode('options') + ->scalarPrototype()->end() + ->end() + ->end() + ->end() + ->booleanNode('structured_output')->defaultTrue()->end() + ->scalarNode('system_prompt') + ->validate() + ->ifTrue(fn ($v) => null !== $v && '' === trim($v)) + ->thenInvalid('The default system prompt must not be an empty string') + ->end() + ->defaultNull() + ->info('The default system prompt of the agent') + ->end() + ->booleanNode('include_tools') + ->info('Include tool definitions at the end of the system prompt') + ->defaultFalse() + ->end() + ->arrayNode('tools') + ->addDefaultsIfNotSet() + ->treatFalseLike(['enabled' => false]) + ->treatTrueLike(['enabled' => true]) + ->treatNullLike(['enabled' => true]) + ->beforeNormalization() + ->ifArray() + ->then(function (array $v) { + return [ + 'enabled' => $v['enabled'] ?? true, + 'services' => $v['services'] ?? $v, + ]; + }) + ->end() + ->children() + ->booleanNode('enabled')->defaultTrue()->end() + ->arrayNode('services') + ->arrayPrototype() + ->children() + ->scalarNode('service')->isRequired()->end() + ->scalarNode('name')->end() + ->scalarNode('description')->end() + ->scalarNode('method')->end() + ->booleanNode('is_agent')->defaultFalse()->end() + ->end() + ->beforeNormalization() + ->ifString() + ->then(function (string $v) { + return ['service' => $v]; + }) + ->end() + ->end() + ->end() + ->end() + ->end() + ->booleanNode('fault_tolerant_toolbox')->defaultTrue()->end() + ->end() + ->end() + ->end() + ->arrayNode('store') + ->children() + ->arrayNode('azure_search') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('endpoint')->isRequired()->end() + ->scalarNode('api_key')->isRequired()->end() + ->scalarNode('index_name')->isRequired()->end() + ->scalarNode('api_version')->isRequired()->end() + ->scalarNode('vector_field')->end() + ->end() + ->end() + ->end() + ->arrayNode('chroma_db') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('collection')->isRequired()->end() + ->end() + ->end() + ->end() + ->arrayNode('mongodb') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('database')->isRequired()->end() + ->scalarNode('collection')->isRequired()->end() + ->scalarNode('index_name')->isRequired()->end() + ->scalarNode('vector_field')->end() + ->booleanNode('bulk_write')->end() + ->end() + ->end() + ->end() + ->arrayNode('pinecone') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('namespace')->end() + ->arrayNode('filter') + ->scalarPrototype()->end() + ->end() + ->integerNode('top_k')->end() + ->end() + ->end() + ->end() + ->end() + ->end() + ->arrayNode('embedder') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('store') + ->info('Service name of store') + ->defaultValue(StoreInterface::class) + ->end() + ->scalarNode('platform') + ->info('Service name of platform') + ->defaultValue(PlatformInterface::class) + ->end() + ->arrayNode('model') + ->children() + ->scalarNode('name')->isRequired()->end() + ->scalarNode('version')->defaultNull()->end() + ->arrayNode('options') + ->scalarPrototype()->end() + ->end() + ->end() + ->end() + ->end() + ->end() + ->end() + ->end() + ; + + return $treeBuilder; + } +} diff --git a/src/ai-bundle/src/Profiler/DataCollector.php b/src/ai-bundle/src/Profiler/DataCollector.php new file mode 100644 index 000000000..fa23f77da --- /dev/null +++ b/src/ai-bundle/src/Profiler/DataCollector.php @@ -0,0 +1,89 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle\Profiler; + +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Platform\Tool\Tool; +use Symfony\Bundle\FrameworkBundle\DataCollector\AbstractDataCollector; +use Symfony\Component\DependencyInjection\Attribute\TaggedIterator; +use Symfony\Component\HttpFoundation\Request; +use Symfony\Component\HttpFoundation\Response; + +/** + * @phpstan-import-type PlatformCallData from TraceablePlatform + * @phpstan-import-type ToolCallData from TraceableToolbox + */ +final class DataCollector extends AbstractDataCollector +{ + /** + * @var TraceablePlatform[] + */ + private readonly array $platforms; + + /** + * @var TraceableToolbox[] + */ + private readonly array $toolboxes; + + /** + * @param TraceablePlatform[] $platforms + * @param TraceableToolbox[] $toolboxes + */ + public function __construct( + #[TaggedIterator('symfony_ai.traceable_platform')] + iterable $platforms, + private readonly ToolboxInterface $defaultToolBox, + #[TaggedIterator('symfony_ai.traceable_toolbox')] + iterable $toolboxes, + ) { + $this->platforms = $platforms instanceof \Traversable ? iterator_to_array($platforms) : $platforms; + $this->toolboxes = $toolboxes instanceof \Traversable ? iterator_to_array($toolboxes) : $toolboxes; + } + + public function collect(Request $request, Response $response, ?\Throwable $exception = null): void + { + $this->data = [ + 'tools' => $this->defaultToolBox->getTools(), + 'platform_calls' => array_merge(...array_map(fn (TraceablePlatform $platform) => $platform->calls, $this->platforms)), + 'tool_calls' => array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->calls, $this->toolboxes)), + ]; + } + + public static function getTemplate(): string + { + return '@AI/data_collector.html.twig'; + } + + /** + * @return PlatformCallData[] + */ + public function getPlatformCalls(): array + { + return $this->data['platform_calls'] ?? []; + } + + /** + * @return Tool[] + */ + public function getTools(): array + { + return $this->data['tools'] ?? []; + } + + /** + * @return ToolCallData[] + */ + public function getToolCalls(): array + { + return $this->data['tool_calls'] ?? []; + } +} diff --git a/src/ai-bundle/src/Profiler/TraceablePlatform.php b/src/ai-bundle/src/Profiler/TraceablePlatform.php new file mode 100644 index 000000000..112cdd527 --- /dev/null +++ b/src/ai-bundle/src/Profiler/TraceablePlatform.php @@ -0,0 +1,56 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle\Profiler; + +use Symfony\AI\Platform\Message\Content\File; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Response\ResponseInterface; + +/** + * @phpstan-type PlatformCallData array{ + * model: Model, + * input: array|string|object, + * options: array, + * response: ResponseInterface, + * } + */ +final class TraceablePlatform implements PlatformInterface +{ + /** + * @var PlatformCallData[] + */ + public array $calls = []; + + public function __construct( + private readonly PlatformInterface $platform, + ) { + } + + public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface + { + $response = $this->platform->request($model, $input, $options); + + if ($input instanceof File) { + $input = $input::class.': '.$input->getFormat(); + } + + $this->calls[] = [ + 'model' => $model, + 'input' => \is_object($input) ? clone $input : $input, + 'options' => $options, + 'response' => $response->getContent(), + ]; + + return $response; + } +} diff --git a/src/ai-bundle/src/Profiler/TraceableToolbox.php b/src/ai-bundle/src/Profiler/TraceableToolbox.php new file mode 100644 index 000000000..9af9b7a40 --- /dev/null +++ b/src/ai-bundle/src/Profiler/TraceableToolbox.php @@ -0,0 +1,51 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle\Profiler; + +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @phpstan-type ToolCallData array{ + * call: ToolCall, + * result: string, + * } + */ +final class TraceableToolbox implements ToolboxInterface +{ + /** + * @var ToolCallData[] + */ + public array $calls = []; + + public function __construct( + private readonly ToolboxInterface $toolbox, + ) { + } + + public function getTools(): array + { + return $this->toolbox->getTools(); + } + + public function execute(ToolCall $toolCall): mixed + { + $result = $this->toolbox->execute($toolCall); + + $this->calls[] = [ + 'call' => $toolCall, + 'result' => $result, + ]; + + return $result; + } +} diff --git a/src/ai-bundle/src/Resources/config/services.php b/src/ai-bundle/src/Resources/config/services.php new file mode 100644 index 000000000..95563fae6 --- /dev/null +++ b/src/ai-bundle/src/Resources/config/services.php @@ -0,0 +1,76 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\Component\DependencyInjection\Loader\Configurator; + +use Symfony\AI\Agent\StructuredOutput\AgentProcessor as StructureOutputProcessor; +use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactory; +use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface; +use Symfony\AI\Agent\Toolbox\AgentProcessor as ToolProcessor; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory; +use Symfony\AI\Agent\Toolbox\ToolFactoryInterface; +use Symfony\AI\AIBundle\Profiler\DataCollector; +use Symfony\AI\AIBundle\Profiler\TraceableToolbox; + +return static function (ContainerConfigurator $container): void { + $container->services() + ->defaults() + ->autowire() + + // structured output + ->set(ResponseFormatFactory::class) + ->alias(ResponseFormatFactoryInterface::class, ResponseFormatFactory::class) + ->set(StructureOutputProcessor::class) + ->tag('symfony_ai.agent.input_processor') + ->tag('symfony_ai.agent.output_processor') + + // tools + ->set('symfony_ai.toolbox.abstract') + ->class(Toolbox::class) + ->autowire() + ->abstract() + ->args([ + '$toolFactory' => service(ToolFactoryInterface::class), + '$tools' => abstract_arg('Collection of tools'), + ]) + ->set(Toolbox::class) + ->parent('symfony_ai.toolbox.abstract') + ->args([ + '$tools' => tagged_iterator('symfony_ai.tool'), + ]) + ->alias(ToolboxInterface::class, Toolbox::class) + ->set(ReflectionToolFactory::class) + ->alias(ToolFactoryInterface::class, ReflectionToolFactory::class) + ->set('symfony_ai.tool.agent_processor.abstract') + ->class(ToolProcessor::class) + ->abstract() + ->args([ + '$toolbox' => abstract_arg('Toolbox'), + ]) + ->set(ToolProcessor::class) + ->parent('symfony_ai.tool.agent_processor.abstract') + ->tag('symfony_ai.agent.input_processor') + ->tag('symfony_ai.agent.output_processor') + ->args([ + '$toolbox' => service(ToolboxInterface::class), + '$eventDispatcher' => service('event_dispatcher')->nullOnInvalid(), + ]) + + // profiler + ->set(DataCollector::class) + ->tag('data_collector') + ->set(TraceableToolbox::class) + ->decorate(ToolboxInterface::class) + ->tag('symfony_ai.traceable_toolbox') + ; +}; diff --git a/src/ai-bundle/src/Resources/views/data_collector.html.twig b/src/ai-bundle/src/Resources/views/data_collector.html.twig new file mode 100644 index 000000000..dc2f93f20 --- /dev/null +++ b/src/ai-bundle/src/Resources/views/data_collector.html.twig @@ -0,0 +1,252 @@ +{% extends '@WebProfiler/Profiler/layout.html.twig' %} + +{% block toolbar %} + {% if collector.platformCalls|length > 0 %} + {% set icon %} + {{ include('@AI/icon.svg', { y: 18 }) }} + {{ collector.platformCalls|length }} + + calls + + {% endset %} + + {% set text %} +
+
+ Configured Platforms + 1 +
+
+ Platform Calls + {{ collector.platformCalls|length }} +
+
+ Registered Tools + {{ collector.tools|length }} +
+
+ Tool Calls + {{ collector.toolCalls|length }} +
+
+ {% endset %} + + {{ include('@WebProfiler/Profiler/toolbar_item.html.twig', { 'link': true }) }} + {% endif %} +{% endblock %} + +{% block menu %} + + {{ include('@AI/icon.svg', { y: 16 }) }} + AI + {{ collector.platformCalls|length }} + +{% endblock %} + +{% macro tool_calls(toolCalls) %} + Tool call{{ toolCalls|length > 1 ? 's' }}: +
    + {% for toolCall in toolCalls %} +
  1. + {{ toolCall.name }}({{ toolCall.arguments|map((value, key) => "#{key}: #{value}")|join(', ') }}) + (ID: {{ toolCall.id }}) +
  2. + {% endfor %} +
+{% endmacro %} + +{% block panel %} +

AI

+
+
+
+ 1 + Platforms +
+
+ {{ collector.platformCalls|length }} + Platform Calls +
+
+
+
+
+ {{ collector.tools|length }} + Tools +
+
+ {{ collector.toolCalls|length }} + Tool Calls +
+
+
+

Platform Calls

+ {% if collector.platformCalls|length %} +
+
+

Platform Calls {{ collector.platformCalls|length }}

+
+ {% for call in collector.platformCalls %} + + + + + + + + + + + + + + + + + + + + + + + + +
Call {{ loop.index }}
Model{{ constant('class', call.model) }} (Version: {{ call.model.name }})
Input + {% if call.input.messages is defined %}{# expect MessageBag #} +
    + {% for message in call.input.messages %} +
  1. + {{ message.role.value|title }}: + {% if 'assistant' == message.role.value and message.hasToolCalls%} + {{ _self.tool_calls(message.toolCalls) }} + {% elseif 'tool' == message.role.value %} + Result of tool call with ID {{ message.toolCall.id }}
    + {{ message.content|nl2br }} + {% elseif 'user' == message.role.value %} + {% for item in message.content %} + {% if item.text is defined %} + {{ item.text|nl2br }} + {% else %} + + {% endif %} + {% endfor %} + {% else %} + {{ message.content|nl2br }} + {% endif %} +
  2. + {% endfor %} +
+ {% else %} + {{ dump(call.input) }} + {% endif %} +
Options +
    + {% for key, value in call.options %} + {% if key == 'tools' %} +
  • {{ key }}: +
      + {% for tool in value %} +
    • {{ tool.name }}
    • + {% endfor %} +
    +
  • + {% else %} +
  • {{ key }}: {{ dump(value) }}
  • + {% endif %} + {% endfor %} +
+
Response + {% if call.input.messages is defined and call.response is iterable %}{# expect array of ToolCall #} + {{ _self.tool_calls(call.response) }} + {% elseif call.response is iterable %}{# expect array of Vectors #} +
    + {% for vector in call.response %} +
  1. Vector with {{ vector.dimensions }} dimensions
  2. + {% endfor %} +
+ {% else %} + {{ call.response }} + {% endif %} +
+ {% endfor %} +
+
+
+ {% else %} +
+

No platform calls were made.

+
+ {% endif %} + +

Tools

+ {% if collector.tools|length %} + + + + + + + + + + + {% for tool in collector.tools %} + + + + + + + {% endfor %} + +
NameDescriptionClass & MethodParameters
{{ tool.name }}{{ tool.description }}{{ tool.reference.class }}::{{ tool.reference.method }} + {% if tool.parameters %} +
    + {% for name, parameter in tool.parameters.properties %} +
  • + {{ name }} ({{ parameter.type }})
    + {{ parameter.description }} +
  • + {% endfor %} +
+ {% else %} + none + {% endif %} +
+ {% else %} +
+

No tools were registered.

+
+ {% endif %} + +

Tool Calls

+ {% if collector.toolCalls|length %} + {% for call in collector.toolCalls %} + + + + + + + + + + + + + + + + + + + + +
{{ call.call.name }}
ID{{ call.call.id }}
Arguments{{ dump(call.call.arguments) }}
Response{{ call.result|nl2br }}
+ {% endfor %} + {% else %} +
+

No tool calls were made.

+
+ {% endif %} +{% endblock %} diff --git a/src/ai-bundle/src/Resources/views/icon.svg b/src/ai-bundle/src/Resources/views/icon.svg new file mode 100644 index 000000000..56f1a6c08 --- /dev/null +++ b/src/ai-bundle/src/Resources/views/icon.svg @@ -0,0 +1,16 @@ + + + + + + LLM + diff --git a/src/ai-bundle/tests/Profiler/TraceableToolboxTest.php b/src/ai-bundle/tests/Profiler/TraceableToolboxTest.php new file mode 100644 index 000000000..6990ca891 --- /dev/null +++ b/src/ai-bundle/tests/Profiler/TraceableToolboxTest.php @@ -0,0 +1,78 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\AIBundle\Tests\Profiler; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\ToolboxInterface; +use Symfony\AI\AIBundle\Profiler\TraceableToolbox; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(TraceableToolbox::class)] +#[Small] +final class TraceableToolboxTest extends TestCase +{ + #[Test] + public function getMap(): void + { + $metadata = new Tool(new ExecutionReference('Foo\Bar'), 'bar', 'description', null); + $toolbox = $this->createToolbox(['tool' => $metadata]); + $traceableToolbox = new TraceableToolbox($toolbox); + + $map = $traceableToolbox->getTools(); + + self::assertSame(['tool' => $metadata], $map); + } + + #[Test] + public function execute(): void + { + $metadata = new Tool(new ExecutionReference('Foo\Bar'), 'bar', 'description', null); + $toolbox = $this->createToolbox(['tool' => $metadata]); + $traceableToolbox = new TraceableToolbox($toolbox); + $toolCall = new ToolCall('foo', '__invoke'); + + $result = $traceableToolbox->execute($toolCall); + + self::assertSame('tool_result', $result); + self::assertCount(1, $traceableToolbox->calls); + self::assertSame($toolCall, $traceableToolbox->calls[0]['call']); + self::assertSame('tool_result', $traceableToolbox->calls[0]['result']); + } + + /** + * @param Tool[] $tools + */ + private function createToolbox(array $tools): ToolboxInterface + { + return new class($tools) implements ToolboxInterface { + public function __construct( + private readonly array $tools, + ) { + } + + public function getTools(): array + { + return $this->tools; + } + + public function execute(ToolCall $toolCall): string + { + return 'tool_result'; + } + }; + } +} diff --git a/src/platform/.gitattributes b/src/platform/.gitattributes new file mode 100644 index 000000000..ec8c01802 --- /dev/null +++ b/src/platform/.gitattributes @@ -0,0 +1,6 @@ +/.github export-ignore +/tests export-ignore +.gitattributes export-ignore +.gitignore export-ignore +phpstan.dist.neon export-ignore +phpunit.xml.dist export-ignore diff --git a/src/platform/.github/PULL_REQUEST_TEMPLATE.md b/src/platform/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..fcb87228a --- /dev/null +++ b/src/platform/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ +Please do not submit any Pull Requests here. They will be closed. +--- + +Please submit your PR here instead: +https://github.com/symfony/ai + +This repository is what we call a "subtree split": a read-only subset of that main repository. +We're looking forward to your PR there! diff --git a/src/platform/.github/workflows/close-pull-request.yml b/src/platform/.github/workflows/close-pull-request.yml new file mode 100644 index 000000000..207153fd5 --- /dev/null +++ b/src/platform/.github/workflows/close-pull-request.yml @@ -0,0 +1,20 @@ +name: Close Pull Request + +on: + pull_request_target: + types: [opened] + +jobs: + run: + runs-on: ubuntu-latest + steps: + - uses: superbrothers/close-pull-request@v3 + with: + comment: | + Thanks for your Pull Request! We love contributions. + + However, you should instead open your PR on the main repository: + https://github.com/symfony/ai + + This repository is what we call a "subtree split": a read-only subset of that main repository. + We're looking forward to your PR there! diff --git a/src/platform/.gitignore b/src/platform/.gitignore new file mode 100644 index 000000000..f43db636b --- /dev/null +++ b/src/platform/.gitignore @@ -0,0 +1,3 @@ +composer.lock +vendor +.phpunit.cache diff --git a/src/platform/LICENSE b/src/platform/LICENSE new file mode 100644 index 000000000..bc38d714e --- /dev/null +++ b/src/platform/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2025-present Fabien Potencier + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/src/platform/composer.json b/src/platform/composer.json new file mode 100644 index 000000000..02912e405 --- /dev/null +++ b/src/platform/composer.json @@ -0,0 +1,74 @@ +{ + "name": "symfony/ai-platform", + "type": "library", + "description": "PHP library for interacting with AI platform provider.", + "keywords": [ + "ai", + "huggingface", + "transformers", + "inference" + ], + "license": "MIT", + "authors": [ + { + "name": "Christopher Hertel", + "email": "mail@christopher-hertel.de" + }, + { + "name": "Oskar Stark", + "email": "oskarstark@googlemail.com" + } + ], + "require": { + "php": ">=8.2", + "ext-fileinfo": "*", + "oskarstark/enum-helper": "^1.5", + "phpdocumentor/reflection-docblock": "^5.4", + "phpstan/phpdoc-parser": "^2.1", + "psr/cache": "^3.0", + "psr/log": "^3.0", + "symfony/clock": "^6.4 || ^7.1", + "symfony/http-client": "^6.4 || ^7.1", + "symfony/property-access": "^6.4 || ^7.1", + "symfony/property-info": "^6.4 || ^7.1", + "symfony/serializer": "^6.4 || ^7.1", + "symfony/type-info": "^7.2.3", + "symfony/uid": "^6.4 || ^7.1", + "webmozart/assert": "^1.11" + }, + "require-dev": { + "codewithkyrian/transformers": "^0.5.3", + "async-aws/bedrock-runtime": "^0.1.0", + "phpstan/phpstan": "^2.0", + "phpstan/phpstan-symfony": "^2.0", + "phpstan/phpstan-webmozart-assert": "^2.0", + "phpunit/phpunit": "^11.5", + "rector/rector": "^2.0", + "symfony/console": "^6.4 || ^7.1", + "symfony/dotenv": "^6.4 || ^7.1", + "symfony/event-dispatcher": "^6.4 || ^7.1", + "symfony/finder": "^6.4 || ^7.1", + "symfony/process": "^6.4 || ^7.1", + "symfony/var-dumper": "^6.4 || ^7.1" + }, + "suggest": { + "async-aws/bedrock-runtime": "For using the Bedrock platform.", + "codewithkyrian/transformers": "For using the TransformersPHP with FFI to run models in PHP." + }, + "config": { + "allow-plugins": { + "codewithkyrian/transformers-libsloader": true + }, + "sort-packages": true + }, + "autoload": { + "psr-4": { + "Symfony\\AI\\Platform\\": "src/" + } + }, + "autoload-dev": { + "psr-4": { + "Symfony\\AI\\Platform\\Tests\\": "tests/" + } + } +} diff --git a/src/platform/phpstan.dist.neon b/src/platform/phpstan.dist.neon new file mode 100644 index 000000000..8cc83f644 --- /dev/null +++ b/src/platform/phpstan.dist.neon @@ -0,0 +1,10 @@ +includes: + - vendor/phpstan/phpstan-webmozart-assert/extension.neon + - vendor/phpstan/phpstan-symfony/extension.neon + +parameters: + level: 6 + paths: + - src/ + - tests/ + diff --git a/src/platform/phpunit.xml.dist b/src/platform/phpunit.xml.dist new file mode 100644 index 000000000..4e9e3a684 --- /dev/null +++ b/src/platform/phpunit.xml.dist @@ -0,0 +1,24 @@ + + + + + tests + + + + + + src + + + diff --git a/src/platform/src/Bridge/Anthropic/Claude.php b/src/platform/src/Bridge/Anthropic/Claude.php new file mode 100644 index 000000000..539d6fbc8 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Claude.php @@ -0,0 +1,47 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class Claude extends Model +{ + public const HAIKU_3 = 'claude-3-haiku-20240307'; + public const HAIKU_35 = 'claude-3-5-haiku-20241022'; + public const SONNET_3 = 'claude-3-sonnet-20240229'; + public const SONNET_35 = 'claude-3-5-sonnet-20240620'; + public const SONNET_35_V2 = 'claude-3-5-sonnet-20241022'; + public const SONNET_37 = 'claude-3-7-sonnet-20250219'; + public const OPUS_3 = 'claude-3-opus-20240229'; + + /** + * @param array $options The default options for the model usage + */ + public function __construct( + string $name = self::SONNET_37, + array $options = ['temperature' => 1.0, 'max_tokens' => 1000], + ) { + $capabilities = [ + Capability::INPUT_MESSAGES, + Capability::INPUT_IMAGE, + Capability::OUTPUT_TEXT, + Capability::OUTPUT_STREAMING, + Capability::TOOL_CALLING, + ]; + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php new file mode 100644 index 000000000..d6a4c97d1 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class AssistantMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return AssistantMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param AssistantMessage $data + * + * @return array{ + * role: 'assistant', + * content: list + * }> + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'role' => 'assistant', + 'content' => array_map(static function (ToolCall $toolCall) { + return [ + 'type' => 'tool_use', + 'id' => $toolCall->id, + 'name' => $toolCall->name, + 'input' => empty($toolCall->arguments) ? new \stdClass() : $toolCall->arguments, + ]; + }, $data->toolCalls), + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/DocumentNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/DocumentNormalizer.php new file mode 100644 index 000000000..2ac4e58a4 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/DocumentNormalizer.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\Content\Document; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class DocumentNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return Document::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param Document $data + * + * @return array{type: 'document', source: array{type: 'base64', media_type: string, data: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'document', + 'source' => [ + 'type' => 'base64', + 'media_type' => $data->getFormat(), + 'data' => $data->asBase64(), + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/DocumentUrlNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/DocumentUrlNormalizer.php new file mode 100644 index 000000000..662e0ea7e --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/DocumentUrlNormalizer.php @@ -0,0 +1,49 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\Content\DocumentUrl; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class DocumentUrlNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return DocumentUrl::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param DocumentUrl $data + * + * @return array{type: 'document', source: array{type: 'url', url: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'document', + 'source' => [ + 'type' => 'url', + 'url' => $data->url, + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/ImageNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/ImageNormalizer.php new file mode 100644 index 000000000..a165189fe --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/ImageNormalizer.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Model; + +use function Symfony\Component\String\u; + +/** + * @author Christopher Hertel + */ +final class ImageNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return Image::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param Image $data + * + * @return array{type: 'image', source: array{type: 'base64', media_type: string, data: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'image', + 'source' => [ + 'type' => 'base64', + 'media_type' => u($data->getFormat())->replace('jpg', 'jpeg')->toString(), + 'data' => $data->asBase64(), + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/ImageUrlNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/ImageUrlNormalizer.php new file mode 100644 index 000000000..e0f422fa8 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/ImageUrlNormalizer.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class ImageUrlNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return ImageUrl::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param ImageUrl $data + * + * @return array{type: 'image', source: array{type: 'url', url: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'image', + 'source' => [ + 'type' => 'url', + 'url' => $data->url, + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/MessageBagNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/MessageBagNormalizer.php new file mode 100644 index 000000000..669e77a66 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/MessageBagNormalizer.php @@ -0,0 +1,64 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class MessageBagNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return MessageBagInterface::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param MessageBagInterface $data + * + * @return array{ + * messages: array, + * model?: string, + * system?: string, + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = [ + 'messages' => $this->normalizer->normalize($data->withoutSystemMessage()->getMessages(), $format, $context), + ]; + + if (null !== $system = $data->getSystemMessage()) { + $array['system'] = $system->content; + } + + if (isset($context[Contract::CONTEXT_MODEL]) && $context[Contract::CONTEXT_MODEL] instanceof Model) { + $array['model'] = $context[Contract::CONTEXT_MODEL]->getName(); + } + + return $array; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/ToolCallMessageNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/ToolCallMessageNormalizer.php new file mode 100644 index 000000000..f2023296f --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/ToolCallMessageNormalizer.php @@ -0,0 +1,63 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class ToolCallMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return ToolCallMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param ToolCallMessage $data + * + * @return array{ + * role: 'user', + * content: list + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'role' => 'user', + 'content' => [ + [ + 'type' => 'tool_result', + 'tool_use_id' => $data->toolCall->id, + 'content' => $data->content, + ], + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/Contract/ToolNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/ToolNormalizer.php new file mode 100644 index 000000000..716d3a771 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/Contract/ToolNormalizer.php @@ -0,0 +1,54 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic\Contract; + +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @phpstan-import-type JsonSchema from Factory + * + * @author Christopher Hertel + */ +class ToolNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return Tool::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } + + /** + * @param Tool $data + * + * @return array{ + * name: string, + * description: string, + * input_schema: JsonSchema|array{type: 'object'} + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'name' => $data->name, + 'description' => $data->description, + 'input_schema' => $data->parameters ?? ['type' => 'object'], + ]; + } +} diff --git a/src/platform/src/Bridge/Anthropic/ModelClient.php b/src/platform/src/Bridge/Anthropic/ModelClient.php new file mode 100644 index 000000000..d5c981a15 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/ModelClient.php @@ -0,0 +1,54 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic; + +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] private string $apiKey, + private string $version = '2023-06-01', + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(Model $model): bool + { + return $model instanceof Claude; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + if (isset($options['tools'])) { + $options['tool_choice'] = ['type' => 'auto']; + } + + return $this->httpClient->request('POST', 'https://api.anthropic.com/v1/messages', [ + 'headers' => [ + 'x-api-key' => $this->apiKey, + 'anthropic-version' => $this->version, + ], + 'json' => array_merge($options, $payload), + ]); + } +} diff --git a/src/platform/src/Bridge/Anthropic/PlatformFactory.php b/src/platform/src/Bridge/Anthropic/PlatformFactory.php new file mode 100644 index 000000000..5a674ca59 --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/PlatformFactory.php @@ -0,0 +1,55 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic; + +use Symfony\AI\Platform\Bridge\Anthropic\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\DocumentNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\DocumentUrlNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\ImageNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\ImageUrlNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\ToolNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + string $version = '2023-06-01', + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + + return new Platform( + [new ModelClient($httpClient, $apiKey, $version)], + [new ResponseConverter()], + Contract::create( + new AssistantMessageNormalizer(), + new DocumentNormalizer(), + new DocumentUrlNormalizer(), + new ImageNormalizer(), + new ImageUrlNormalizer(), + new MessageBagNormalizer(), + new ToolCallMessageNormalizer(), + new ToolNormalizer(), + ) + ); + } +} diff --git a/src/platform/src/Bridge/Anthropic/ResponseConverter.php b/src/platform/src/Bridge/Anthropic/ResponseConverter.php new file mode 100644 index 000000000..805baf34a --- /dev/null +++ b/src/platform/src/Bridge/Anthropic/ResponseConverter.php @@ -0,0 +1,88 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Anthropic; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\StreamResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Component\HttpClient\Chunk\ServerSentEvent; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Component\HttpClient\Exception\JsonException; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +class ResponseConverter implements ResponseConverterInterface +{ + public function supports(Model $model): bool + { + return $model instanceof Claude; + } + + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + if ($options['stream'] ?? false) { + return new StreamResponse($this->convertStream($response)); + } + + $data = $response->toArray(); + + if (!isset($data['content']) || 0 === \count($data['content'])) { + throw new RuntimeException('Response does not contain any content'); + } + + $toolCalls = []; + foreach ($data['content'] as $content) { + if ('tool_use' === $content['type']) { + $toolCalls[] = new ToolCall($content['id'], $content['name'], $content['input']); + } + } + + if (!isset($data['content'][0]['text']) && 0 === \count($toolCalls)) { + throw new RuntimeException('Response content does not contain any text nor tool calls.'); + } + + if (!empty($toolCalls)) { + return new ToolCallResponse(...$toolCalls); + } + + return new TextResponse($data['content'][0]['text']); + } + + private function convertStream(ResponseInterface $response): \Generator + { + foreach ((new EventSourceHttpClient())->stream($response) as $chunk) { + if (!$chunk instanceof ServerSentEvent || '[DONE]' === $chunk->getData()) { + continue; + } + + try { + $data = $chunk->getArrayData(); + } catch (JsonException) { + // try catch only needed for Symfony 6.4 + continue; + } + + if ('content_block_delta' != $data['type'] || !isset($data['delta']['text'])) { + continue; + } + + yield $data['delta']['text']; + } + } +} diff --git a/src/platform/src/Bridge/Azure/Meta/LlamaHandler.php b/src/platform/src/Bridge/Azure/Meta/LlamaHandler.php new file mode 100644 index 000000000..40278c464 --- /dev/null +++ b/src/platform/src/Bridge/Azure/Meta/LlamaHandler.php @@ -0,0 +1,64 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Azure\Meta; + +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class LlamaHandler implements ModelClientInterface, ResponseConverterInterface +{ + public function __construct( + private HttpClientInterface $httpClient, + private string $baseUrl, + #[\SensitiveParameter] private string $apiKey, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Llama; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + $url = \sprintf('https://%s/chat/completions', $this->baseUrl); + + return $this->httpClient->request('POST', $url, [ + 'headers' => [ + 'Content-Type' => 'application/json', + 'Authorization' => $this->apiKey, + ], + 'json' => array_merge($options, $payload), + ]); + } + + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + $data = $response->toArray(); + + if (!isset($data['choices'][0]['message']['content'])) { + throw new RuntimeException('Response does not contain output'); + } + + return new TextResponse($data['choices'][0]['message']['content']); + } +} diff --git a/src/platform/src/Bridge/Azure/Meta/PlatformFactory.php b/src/platform/src/Bridge/Azure/Meta/PlatformFactory.php new file mode 100644 index 000000000..59b8b5e64 --- /dev/null +++ b/src/platform/src/Bridge/Azure/Meta/PlatformFactory.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Azure\Meta; + +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\HttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class PlatformFactory +{ + public static function create( + string $baseUrl, + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $modelClient = new LlamaHandler($httpClient ?? HttpClient::create(), $baseUrl, $apiKey); + + return new Platform([$modelClient], [$modelClient]); + } +} diff --git a/src/platform/src/Bridge/Azure/OpenAI/EmbeddingsModelClient.php b/src/platform/src/Bridge/Azure/OpenAI/EmbeddingsModelClient.php new file mode 100644 index 000000000..a45dd0661 --- /dev/null +++ b/src/platform/src/Bridge/Azure/OpenAI/EmbeddingsModelClient.php @@ -0,0 +1,64 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Azure\OpenAI; + +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class EmbeddingsModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + private string $baseUrl, + private string $deployment, + private string $apiVersion, + #[\SensitiveParameter] private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + Assert::notStartsWith($baseUrl, 'http://', 'The base URL must not contain the protocol.'); + Assert::notStartsWith($baseUrl, 'https://', 'The base URL must not contain the protocol.'); + Assert::stringNotEmpty($deployment, 'The deployment must not be empty.'); + Assert::stringNotEmpty($apiVersion, 'The API version must not be empty.'); + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + } + + public function supports(Model $model): bool + { + return $model instanceof Embeddings; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + $url = \sprintf('https://%s/openai/deployments/%s/embeddings', $this->baseUrl, $this->deployment); + + return $this->httpClient->request('POST', $url, [ + 'headers' => [ + 'api-key' => $this->apiKey, + ], + 'query' => ['api-version' => $this->apiVersion], + 'json' => array_merge($options, [ + 'model' => $model->getName(), + 'input' => $payload, + ]), + ]); + } +} diff --git a/src/platform/src/Bridge/Azure/OpenAI/GPTModelClient.php b/src/platform/src/Bridge/Azure/OpenAI/GPTModelClient.php new file mode 100644 index 000000000..ac7f9a216 --- /dev/null +++ b/src/platform/src/Bridge/Azure/OpenAI/GPTModelClient.php @@ -0,0 +1,61 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Azure\OpenAI; + +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class GPTModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + private string $baseUrl, + private string $deployment, + private string $apiVersion, + #[\SensitiveParameter] private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + Assert::notStartsWith($baseUrl, 'http://', 'The base URL must not contain the protocol.'); + Assert::notStartsWith($baseUrl, 'https://', 'The base URL must not contain the protocol.'); + Assert::stringNotEmpty($deployment, 'The deployment must not be empty.'); + Assert::stringNotEmpty($apiVersion, 'The API version must not be empty.'); + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + } + + public function supports(Model $model): bool + { + return $model instanceof GPT; + } + + public function request(Model $model, object|array|string $payload, array $options = []): ResponseInterface + { + $url = \sprintf('https://%s/openai/deployments/%s/chat/completions', $this->baseUrl, $this->deployment); + + return $this->httpClient->request('POST', $url, [ + 'headers' => [ + 'api-key' => $this->apiKey, + ], + 'query' => ['api-version' => $this->apiVersion], + 'json' => array_merge($options, $payload), + ]); + } +} diff --git a/src/platform/src/Bridge/Azure/OpenAI/PlatformFactory.php b/src/platform/src/Bridge/Azure/OpenAI/PlatformFactory.php new file mode 100644 index 000000000..fd4af9489 --- /dev/null +++ b/src/platform/src/Bridge/Azure/OpenAI/PlatformFactory.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Azure\OpenAI; + +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Bridge\OpenAI\GPT\ResponseConverter; +use Symfony\AI\Platform\Bridge\OpenAI\Whisper\AudioNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class PlatformFactory +{ + public static function create( + string $baseUrl, + string $deployment, + string $apiVersion, + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + $embeddingsResponseFactory = new EmbeddingsModelClient($httpClient, $baseUrl, $deployment, $apiVersion, $apiKey); + $GPTResponseFactory = new GPTModelClient($httpClient, $baseUrl, $deployment, $apiVersion, $apiKey); + $whisperResponseFactory = new WhisperModelClient($httpClient, $baseUrl, $deployment, $apiVersion, $apiKey); + + return new Platform( + [$GPTResponseFactory, $embeddingsResponseFactory, $whisperResponseFactory], + [new ResponseConverter(), new Embeddings\ResponseConverter(), new \Symfony\AI\Platform\Bridge\OpenAI\Whisper\ResponseConverter()], + Contract::create(new AudioNormalizer()), + ); + } +} diff --git a/src/platform/src/Bridge/Azure/OpenAI/WhisperModelClient.php b/src/platform/src/Bridge/Azure/OpenAI/WhisperModelClient.php new file mode 100644 index 000000000..a6b7c5d2f --- /dev/null +++ b/src/platform/src/Bridge/Azure/OpenAI/WhisperModelClient.php @@ -0,0 +1,62 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Azure\OpenAI; + +use Symfony\AI\Platform\Bridge\OpenAI\Whisper; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class WhisperModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + private string $baseUrl, + private string $deployment, + private string $apiVersion, + #[\SensitiveParameter] private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + Assert::notStartsWith($baseUrl, 'http://', 'The base URL must not contain the protocol.'); + Assert::notStartsWith($baseUrl, 'https://', 'The base URL must not contain the protocol.'); + Assert::stringNotEmpty($deployment, 'The deployment must not be empty.'); + Assert::stringNotEmpty($apiVersion, 'The API version must not be empty.'); + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + } + + public function supports(Model $model): bool + { + return $model instanceof Whisper; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + $url = \sprintf('https://%s/openai/deployments/%s/audio/translations', $this->baseUrl, $this->deployment); + + return $this->httpClient->request('POST', $url, [ + 'headers' => [ + 'api-key' => $this->apiKey, + 'Content-Type' => 'multipart/form-data', + ], + 'query' => ['api-version' => $this->apiVersion], + 'body' => array_merge($options, $payload), + ]); + } +} diff --git a/src/platform/src/Bridge/Bedrock/Anthropic/ClaudeHandler.php b/src/platform/src/Bridge/Bedrock/Anthropic/ClaudeHandler.php new file mode 100644 index 000000000..08977da56 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Anthropic/ClaudeHandler.php @@ -0,0 +1,97 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Anthropic; + +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; +use AsyncAws\BedrockRuntime\Input\InvokeModelRequest; +use AsyncAws\BedrockRuntime\Result\InvokeModelResponse; +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Bridge\Bedrock\BedrockModelClient; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; + +/** + * @author Björn Altmann + */ +final readonly class ClaudeHandler implements BedrockModelClient +{ + public function __construct( + private BedrockRuntimeClient $bedrockRuntimeClient, + private string $version = '2023-05-31', + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Claude; + } + + public function request(Model $model, array|string $payload, array $options = []): LlmResponse + { + unset($payload['model']); + + if (isset($options['tools'])) { + $options['tool_choice'] = ['type' => 'auto']; + } + + if (!isset($options['anthropic_version'])) { + $options['anthropic_version'] = 'bedrock-'.$this->version; + } + + $request = [ + 'modelId' => $this->getModelId($model), + 'contentType' => 'application/json', + 'body' => json_encode(array_merge($options, $payload), \JSON_THROW_ON_ERROR), + ]; + + $invokeModelResponse = $this->bedrockRuntimeClient->invokeModel(new InvokeModelRequest($request)); + + return $this->convert($invokeModelResponse); + } + + public function convert(InvokeModelResponse $bedrockResponse): LlmResponse + { + $data = json_decode($bedrockResponse->getBody(), true, 512, \JSON_THROW_ON_ERROR); + + if (!isset($data['content']) || 0 === \count($data['content'])) { + throw new RuntimeException('Response does not contain any content'); + } + + if (!isset($data['content'][0]['text']) && !isset($data['content'][0]['type'])) { + throw new RuntimeException('Response content does not contain any text or type'); + } + + $toolCalls = []; + foreach ($data['content'] as $content) { + if ('tool_use' === $content['type']) { + $toolCalls[] = new ToolCall($content['id'], $content['name'], $content['input']); + } + } + if (!empty($toolCalls)) { + return new ToolCallResponse(...$toolCalls); + } + + return new TextResponse($data['content'][0]['text']); + } + + private function getModelId(Model $model): string + { + $configuredRegion = $this->bedrockRuntimeClient->getConfiguration()->get('region'); + $regionPrefix = substr((string) $configuredRegion, 0, 2); + + return $regionPrefix.'.anthropic.'.$model->getName().'-v1:0'; + } +} diff --git a/src/platform/src/Bridge/Bedrock/BedrockModelClient.php b/src/platform/src/Bridge/Bedrock/BedrockModelClient.php new file mode 100644 index 000000000..25fe1121c --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/BedrockModelClient.php @@ -0,0 +1,29 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock; + +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; + +/** + * @author Björn Altmann + */ +interface BedrockModelClient +{ + public function supports(Model $model): bool; + + /** + * @param array|string $payload + * @param array $options + */ + public function request(Model $model, array|string $payload, array $options = []): LlmResponse; +} diff --git a/src/platform/src/Bridge/Bedrock/Meta/LlamaModelClient.php b/src/platform/src/Bridge/Bedrock/Meta/LlamaModelClient.php new file mode 100644 index 000000000..4b6e31a57 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Meta/LlamaModelClient.php @@ -0,0 +1,68 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Meta; + +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; +use AsyncAws\BedrockRuntime\Input\InvokeModelRequest; +use AsyncAws\BedrockRuntime\Result\InvokeModelResponse; +use Symfony\AI\Platform\Bridge\Bedrock\BedrockModelClient; +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; + +/** + * @author Björn Altmann + */ +class LlamaModelClient implements BedrockModelClient +{ + public function __construct( + private readonly BedrockRuntimeClient $bedrockRuntimeClient, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Llama; + } + + public function request(Model $model, array|string $payload, array $options = []): LlmResponse + { + $response = $this->bedrockRuntimeClient->invokeModel(new InvokeModelRequest([ + 'modelId' => $this->getModelId($model), + 'contentType' => 'application/json', + 'body' => json_encode($payload, \JSON_THROW_ON_ERROR), + ])); + + return $this->convert($response); + } + + public function convert(InvokeModelResponse $bedrockResponse): LlmResponse + { + $responseBody = json_decode($bedrockResponse->getBody(), true, 512, \JSON_THROW_ON_ERROR); + + if (!isset($responseBody['generation'])) { + throw new \RuntimeException('Response does not contain any content'); + } + + return new TextResponse($responseBody['generation']); + } + + private function getModelId(Model $model): string + { + $configuredRegion = $this->bedrockRuntimeClient->getConfiguration()->get('region'); + $regionPrefix = substr((string) $configuredRegion, 0, 2); + $modifiedModelName = str_replace('llama-3', 'llama3', $model->getName()); + + return $regionPrefix.'.meta.'.str_replace('.', '-', $modifiedModelName).'-v1:0'; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/Contract/AssistantMessageNormalizer.php b/src/platform/src/Bridge/Bedrock/Nova/Contract/AssistantMessageNormalizer.php new file mode 100644 index 000000000..bf6378fd0 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/Contract/AssistantMessageNormalizer.php @@ -0,0 +1,72 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract; + +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @author Christopher Hertel + */ +final class AssistantMessageNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return AssistantMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Nova; + } + + /** + * @param AssistantMessage $data + * + * @return array{ + * role: 'assistant', + * content: array + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + if ($data->hasToolCalls()) { + return [ + 'role' => 'assistant', + 'content' => array_map(static function (ToolCall $toolCall) { + return [ + 'toolUse' => [ + 'toolUseId' => $toolCall->id, + 'name' => $toolCall->name, + 'input' => empty($toolCall->arguments) ? new \stdClass() : $toolCall->arguments, + ], + ]; + }, $data->toolCalls), + ]; + } + + return [ + 'role' => 'assistant', + 'content' => [['text' => $data->content]], + ]; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/Contract/MessageBagNormalizer.php b/src/platform/src/Bridge/Bedrock/Nova/Contract/MessageBagNormalizer.php new file mode 100644 index 000000000..45a3338db --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/Contract/MessageBagNormalizer.php @@ -0,0 +1,58 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract; + +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class MessageBagNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return MessageBagInterface::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Nova; + } + + /** + * @param MessageBagInterface $data + * + * @return array{ + * messages: array>, + * system?: array, + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = []; + + if ($data->getSystemMessage()) { + $array['system'][]['text'] = $data->getSystemMessage()->content; + } + + $array['messages'] = $this->normalizer->normalize($data->withoutSystemMessage()->getMessages(), $format, $context); + + return $array; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/Contract/ToolCallMessageNormalizer.php b/src/platform/src/Bridge/Bedrock/Nova/Contract/ToolCallMessageNormalizer.php new file mode 100644 index 000000000..92cd2e7a8 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/Contract/ToolCallMessageNormalizer.php @@ -0,0 +1,65 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract; + +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class ToolCallMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return ToolCallMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Nova; + } + + /** + * @param ToolCallMessage $data + * + * @return array{ + * role: 'user', + * content: array, + * } + * }> + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'role' => 'user', + 'content' => [ + [ + 'toolResult' => [ + 'toolUseId' => $data->toolCall->id, + 'content' => [['json' => $data->content]], + ], + ], + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/Contract/ToolNormalizer.php b/src/platform/src/Bridge/Bedrock/Nova/Contract/ToolNormalizer.php new file mode 100644 index 000000000..8bdd3793f --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/Contract/ToolNormalizer.php @@ -0,0 +1,62 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract; + +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @phpstan-import-type JsonSchema from Factory + * + * @author Christopher Hertel + */ +class ToolNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return Tool::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Nova; + } + + /** + * @param Tool $data + * + * @return array{ + * toolSpec: array{ + * name: string, + * description: string, + * inputSchema: array{ + * json: JsonSchema|array{type: 'object'} + * } + * } + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'toolSpec' => [ + 'name' => $data->name, + 'description' => $data->description, + 'inputSchema' => [ + 'json' => $data->parameters ?? new \stdClass(), + ], + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/Contract/UserMessageNormalizer.php b/src/platform/src/Bridge/Bedrock/Nova/Contract/UserMessageNormalizer.php new file mode 100644 index 000000000..31776d5a1 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/Contract/UserMessageNormalizer.php @@ -0,0 +1,72 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract; + +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; + +use function Symfony\Component\String\u; + +/** + * @author Christopher Hertel + */ +final class UserMessageNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return UserMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Nova; + } + + /** + * @param UserMessage $data + * + * @return array{ + * role: 'user', + * content: array + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = ['role' => $data->getRole()->value]; + + foreach ($data->content as $value) { + $contentPart = []; + if ($value instanceof Text) { + $contentPart['text'] = $value->text; + } elseif ($value instanceof Image) { + $contentPart['image']['format'] = u($value->getFormat())->replace('image/', '')->replace('jpg', 'jpeg')->toString(); + $contentPart['image']['source']['bytes'] = $value->asBase64(); + } else { + throw new RuntimeException('Unsupported message type.'); + } + $array['content'][] = $contentPart; + } + + return $array; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/Nova.php b/src/platform/src/Bridge/Bedrock/Nova/Nova.php new file mode 100644 index 000000000..826ae533e --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/Nova.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Björn Altmann + */ +final class Nova extends Model +{ + public const MICRO = 'nova-micro'; + public const LITE = 'nova-lite'; + public const PRO = 'nova-pro'; + public const PREMIER = 'nova-premier'; + + /** + * @param array $options The default options for the model usage + */ + public function __construct( + string $name = self::PRO, + array $options = ['temperature' => 1.0, 'max_tokens' => 1000], + ) { + $capabilities = [ + Capability::INPUT_MESSAGES, + Capability::OUTPUT_TEXT, + Capability::TOOL_CALLING, + ]; + + if (self::MICRO !== $name) { + $capabilities[] = Capability::INPUT_IMAGE; + } + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/Bedrock/Nova/NovaHandler.php b/src/platform/src/Bridge/Bedrock/Nova/NovaHandler.php new file mode 100644 index 000000000..9e7624c3e --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Nova/NovaHandler.php @@ -0,0 +1,98 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Nova; + +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; +use AsyncAws\BedrockRuntime\Input\InvokeModelRequest; +use AsyncAws\BedrockRuntime\Result\InvokeModelResponse; +use Symfony\AI\Platform\Bridge\Bedrock\BedrockModelClient; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; + +/** + * @author Björn Altmann + */ +class NovaHandler implements BedrockModelClient +{ + public function __construct( + private readonly BedrockRuntimeClient $bedrockRuntimeClient, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Nova; + } + + public function request(Model $model, array|string $payload, array $options = []): LlmResponse + { + $modelOptions = []; + if (isset($options['tools'])) { + $modelOptions['toolConfig']['tools'] = $options['tools']; + } + + if (isset($options['temperature'])) { + $modelOptions['inferenceConfig']['temperature'] = $options['temperature']; + } + + if (isset($options['max_tokens'])) { + $modelOptions['inferenceConfig']['maxTokens'] = $options['max_tokens']; + } + + $request = [ + 'modelId' => $this->getModelId($model), + 'contentType' => 'application/json', + 'body' => json_encode(array_merge($payload, $modelOptions), \JSON_THROW_ON_ERROR), + ]; + + $invokeModelResponse = $this->bedrockRuntimeClient->invokeModel(new InvokeModelRequest($request)); + + return $this->convert($invokeModelResponse); + } + + public function convert(InvokeModelResponse $bedrockResponse): LlmResponse + { + $data = json_decode($bedrockResponse->getBody(), true, 512, \JSON_THROW_ON_ERROR); + + if (!isset($data['output']) || 0 === \count($data['output'])) { + throw new RuntimeException('Response does not contain any content'); + } + + if (!isset($data['output']['message']['content'][0]['text'])) { + throw new RuntimeException('Response content does not contain any text'); + } + + $toolCalls = []; + foreach ($data['output']['message']['content'] as $content) { + if (isset($content['toolUse'])) { + $toolCalls[] = new ToolCall($content['toolUse']['toolUseId'], $content['toolUse']['name'], $content['toolUse']['input']); + } + } + if (!empty($toolCalls)) { + return new ToolCallResponse(...$toolCalls); + } + + return new TextResponse($data['output']['message']['content'][0]['text']); + } + + private function getModelId(Model $model): string + { + $configuredRegion = $this->bedrockRuntimeClient->getConfiguration()->get('region'); + $regionPrefix = substr((string) $configuredRegion, 0, 2); + + return $regionPrefix.'.amazon.'.$model->getName().'-v1:0'; + } +} diff --git a/src/platform/src/Bridge/Bedrock/Platform.php b/src/platform/src/Bridge/Bedrock/Platform.php new file mode 100644 index 000000000..23f23dde6 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Platform.php @@ -0,0 +1,85 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock; + +use Symfony\AI\Platform\Bridge\Anthropic\Contract as AnthropicContract; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract as NovaContract; +use Symfony\AI\Platform\Bridge\Meta\Contract as LlamaContract; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Response\ResponseInterface; + +/** + * @author Björn Altmann + */ +class Platform implements PlatformInterface +{ + /** + * @var BedrockModelClient[] + */ + private readonly array $modelClients; + + /** + * @param iterable $modelClients + */ + public function __construct( + iterable $modelClients, + private ?Contract $contract = null, + ) { + $this->contract = $contract ?? Contract::create( + new AnthropicContract\AssistantMessageNormalizer(), + new AnthropicContract\DocumentNormalizer(), + new AnthropicContract\DocumentUrlNormalizer(), + new AnthropicContract\ImageNormalizer(), + new AnthropicContract\ImageUrlNormalizer(), + new AnthropicContract\MessageBagNormalizer(), + new AnthropicContract\ToolCallMessageNormalizer(), + new AnthropicContract\ToolNormalizer(), + new LlamaContract\MessageBagNormalizer(), + new NovaContract\AssistantMessageNormalizer(), + new NovaContract\MessageBagNormalizer(), + new NovaContract\ToolCallMessageNormalizer(), + new NovaContract\ToolNormalizer(), + new NovaContract\UserMessageNormalizer(), + ); + $this->modelClients = $modelClients instanceof \Traversable ? iterator_to_array($modelClients) : $modelClients; + } + + public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface + { + $payload = $this->contract->createRequestPayload($model, $input); + $options = array_merge($model->getOptions(), $options); + + if (isset($options['tools'])) { + $options['tools'] = $this->contract->createToolOption($options['tools'], $model); + } + + return $this->doRequest($model, $payload, $options); + } + + /** + * @param array|string $payload + * @param array $options + */ + private function doRequest(Model $model, array|string $payload, array $options = []): ResponseInterface + { + foreach ($this->modelClients as $modelClient) { + if ($modelClient->supports($model)) { + return $modelClient->request($model, $payload, $options); + } + } + + throw new RuntimeException('No response factory registered for model "'.$model::class.'" with given input.'); + } +} diff --git a/src/platform/src/Bridge/Bedrock/PlatformFactory.php b/src/platform/src/Bridge/Bedrock/PlatformFactory.php new file mode 100644 index 000000000..bd5d8784c --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/PlatformFactory.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock; + +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; +use Symfony\AI\Platform\Bridge\Bedrock\Anthropic\ClaudeHandler; +use Symfony\AI\Platform\Bridge\Bedrock\Meta\LlamaModelClient; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\NovaHandler; + +/** + * @author Björn Altmann + */ +final readonly class PlatformFactory +{ + public static function create( + BedrockRuntimeClient $bedrockRuntimeClient = new BedrockRuntimeClient(), + ): Platform { + $modelClient[] = new ClaudeHandler($bedrockRuntimeClient); + $modelClient[] = new NovaHandler($bedrockRuntimeClient); + $modelClient[] = new LlamaModelClient($bedrockRuntimeClient); + + return new Platform($modelClient); + } +} diff --git a/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php b/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php new file mode 100644 index 000000000..11663a747 --- /dev/null +++ b/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php @@ -0,0 +1,49 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google\Contract; + +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class AssistantMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return AssistantMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param AssistantMessage $data + * + * @return array{array{text: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + ['text' => $data->content], + ]; + } +} diff --git a/src/platform/src/Bridge/Google/Contract/MessageBagNormalizer.php b/src/platform/src/Bridge/Google/Contract/MessageBagNormalizer.php new file mode 100644 index 000000000..1321e2874 --- /dev/null +++ b/src/platform/src/Bridge/Google/Contract/MessageBagNormalizer.php @@ -0,0 +1,69 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google\Contract; + +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Message\Role; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +final class MessageBagNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return MessageBagInterface::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param MessageBagInterface $data + * + * @return array{ + * contents: list + * }>, + * system_instruction?: array{parts: array{text: string}} + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = ['contents' => []]; + + if (null !== $systemMessage = $data->getSystemMessage()) { + $array['system_instruction'] = [ + 'parts' => ['text' => $systemMessage->content], + ]; + } + + foreach ($data->withoutSystemMessage()->getMessages() as $message) { + $array['contents'][] = [ + 'role' => $message->getRole()->equals(Role::Assistant) ? 'model' : 'user', + 'parts' => $this->normalizer->normalize($message, $format, $context), + ]; + } + + return $array; + } +} diff --git a/src/platform/src/Bridge/Google/Contract/UserMessageNormalizer.php b/src/platform/src/Bridge/Google/Contract/UserMessageNormalizer.php new file mode 100644 index 000000000..2f41462b7 --- /dev/null +++ b/src/platform/src/Bridge/Google/Contract/UserMessageNormalizer.php @@ -0,0 +1,58 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google\Contract; + +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class UserMessageNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return UserMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param UserMessage $data + * + * @return list + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $parts = []; + foreach ($data->content as $content) { + if ($content instanceof Text) { + $parts[] = ['text' => $content->text]; + } + if ($content instanceof Image) { + $parts[] = ['inline_data' => [ + 'mime_type' => $content->getFormat(), + 'data' => $content->asBase64(), + ]]; + } + } + + return $parts; + } +} diff --git a/src/platform/src/Bridge/Google/Gemini.php b/src/platform/src/Bridge/Google/Gemini.php new file mode 100644 index 000000000..ec52fc787 --- /dev/null +++ b/src/platform/src/Bridge/Google/Gemini.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Roy Garrido + */ +class Gemini extends Model +{ + public const GEMINI_2_FLASH = 'gemini-2.0-flash'; + public const GEMINI_2_PRO = 'gemini-2.0-pro-exp-02-05'; + public const GEMINI_2_FLASH_LITE = 'gemini-2.0-flash-lite-preview-02-05'; + public const GEMINI_2_FLASH_THINKING = 'gemini-2.0-flash-thinking-exp-01-21'; + public const GEMINI_1_5_FLASH = 'gemini-1.5-flash'; + + /** + * @param array $options The default options for the model usage + */ + public function __construct(string $name = self::GEMINI_2_PRO, array $options = ['temperature' => 1.0]) + { + $capabilities = [ + Capability::INPUT_MESSAGES, + Capability::INPUT_IMAGE, + Capability::OUTPUT_STREAMING, + ]; + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/Google/ModelHandler.php b/src/platform/src/Bridge/Google/ModelHandler.php new file mode 100644 index 000000000..57bc4a0ed --- /dev/null +++ b/src/platform/src/Bridge/Google/ModelHandler.php @@ -0,0 +1,132 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\StreamResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; +use Symfony\Contracts\HttpClient\Exception\DecodingExceptionInterface; +use Symfony\Contracts\HttpClient\Exception\RedirectionExceptionInterface; +use Symfony\Contracts\HttpClient\Exception\ServerExceptionInterface; +use Symfony\Contracts\HttpClient\Exception\TransportExceptionInterface; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Roy Garrido + */ +final readonly class ModelHandler implements ModelClientInterface, ResponseConverterInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @throws TransportExceptionInterface + */ + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + $url = \sprintf( + 'https://generativelanguage.googleapis.com/v1beta/models/%s:%s', + $model->getName(), + $options['stream'] ?? false ? 'streamGenerateContent' : 'generateContent', + ); + + $generationConfig = ['generationConfig' => $options]; + unset($generationConfig['generationConfig']['stream']); + + return $this->httpClient->request('POST', $url, [ + 'headers' => [ + 'x-goog-api-key' => $this->apiKey, + ], + 'json' => array_merge($generationConfig, $payload), + ]); + } + + /** + * @throws TransportExceptionInterface + * @throws ServerExceptionInterface + * @throws RedirectionExceptionInterface + * @throws DecodingExceptionInterface + * @throws ClientExceptionInterface + */ + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + if ($options['stream'] ?? false) { + return new StreamResponse($this->convertStream($response)); + } + + $data = $response->toArray(); + + if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + throw new RuntimeException('Response does not contain any content'); + } + + return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']); + } + + private function convertStream(ResponseInterface $response): \Generator + { + foreach ((new EventSourceHttpClient())->stream($response) as $chunk) { + if ($chunk->isFirst() || $chunk->isLast()) { + continue; + } + + $jsonDelta = trim($chunk->getContent()); + + // Remove leading/trailing brackets + if (str_starts_with($jsonDelta, '[') || str_starts_with($jsonDelta, ',')) { + $jsonDelta = substr($jsonDelta, 1); + } + if (str_ends_with($jsonDelta, ']')) { + $jsonDelta = substr($jsonDelta, 0, -1); + } + + // Split in case of multiple JSON objects + $deltas = explode(",\r\n", $jsonDelta); + + foreach ($deltas as $delta) { + if ('' === $delta) { + continue; + } + + try { + $data = json_decode($delta, true, 512, \JSON_THROW_ON_ERROR); + } catch (\JsonException $e) { + throw new RuntimeException('Failed to decode JSON response', 0, $e); + } + + if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + continue; + } + + yield $data['candidates'][0]['content']['parts'][0]['text']; + } + } + } +} diff --git a/src/platform/src/Bridge/Google/PlatformFactory.php b/src/platform/src/Bridge/Google/PlatformFactory.php new file mode 100644 index 000000000..cab1c3b3d --- /dev/null +++ b/src/platform/src/Bridge/Google/PlatformFactory.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google; + +use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Roy Garrido + */ +final readonly class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + $responseHandler = new ModelHandler($httpClient, $apiKey); + + return new Platform([$responseHandler], [$responseHandler], Contract::create( + new AssistantMessageNormalizer(), + new MessageBagNormalizer(), + new UserMessageNormalizer(), + )); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/ApiClient.php b/src/platform/src/Bridge/HuggingFace/ApiClient.php new file mode 100644 index 000000000..4a73325eb --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/ApiClient.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace; + +use Symfony\AI\Platform\Model; +use Symfony\Component\HttpClient\HttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final class ApiClient +{ + public function __construct( + private ?HttpClientInterface $httpClient = null, + ) { + $this->httpClient = $httpClient ?? HttpClient::create(); + } + + /** + * @return Model[] + */ + public function models(?string $provider, ?string $task): array + { + $response = $this->httpClient->request('GET', 'https://huggingface.co/api/models', [ + 'query' => [ + 'inference_provider' => $provider, + 'pipeline_tag' => $task, + ], + ]); + + return array_map(fn (array $model) => new Model($model['id']), $response->toArray()); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Contract/FileNormalizer.php b/src/platform/src/Bridge/HuggingFace/Contract/FileNormalizer.php new file mode 100644 index 000000000..8c5e7a0b0 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Contract/FileNormalizer.php @@ -0,0 +1,48 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Contract; + +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\Content\File; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class FileNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return File::class; + } + + protected function supportsModel(Model $model): bool + { + return true; + } + + /** + * @param File $data + * + * @return array{ + * headers: array<'Content-Type', string>, + * body: string + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'headers' => ['Content-Type' => $data->getFormat()], + 'body' => $data->asBinary(), + ]; + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Contract/MessageBagNormalizer.php b/src/platform/src/Bridge/HuggingFace/Contract/MessageBagNormalizer.php new file mode 100644 index 000000000..29f8a4581 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Contract/MessageBagNormalizer.php @@ -0,0 +1,54 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Contract; + +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; + +/** + * @author Christopher Hertel + */ +class MessageBagNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return MessageBagInterface::class; + } + + protected function supportsModel(Model $model): bool + { + return true; + } + + /** + * @param MessageBagInterface $data + * + * @return array{ + * headers: array<'Content-Type', 'application/json'>, + * json: array{messages: array} + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'headers' => ['Content-Type' => 'application/json'], + 'json' => [ + 'messages' => $this->normalizer->normalize($data->getMessages(), $format, $context), + ], + ]; + } +} diff --git a/src/platform/src/Bridge/HuggingFace/ModelClient.php b/src/platform/src/Bridge/HuggingFace/ModelClient.php new file mode 100644 index 000000000..dc6595e47 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/ModelClient.php @@ -0,0 +1,94 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace; + +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface as PlatformModelClient; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements PlatformModelClient +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + private string $provider, + #[\SensitiveParameter] + private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(Model $model): bool + { + return true; + } + + /** + * The difference in HuggingFace here is that we treat the payload as the options for the request not only the body. + */ + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + // Extract task from options if provided + $task = $options['task'] ?? null; + unset($options['task']); + + return $this->httpClient->request('POST', $this->getUrl($model, $task), [ + 'auth_bearer' => $this->apiKey, + ...$this->getPayload($payload, $options), + ]); + } + + private function getUrl(Model $model, ?string $task): string + { + $endpoint = Task::FEATURE_EXTRACTION === $task ? 'pipeline/feature-extraction' : 'models'; + $url = \sprintf('https://router.huggingface.co/%s/%s/%s', $this->provider, $endpoint, $model->getName()); + + if (Task::CHAT_COMPLETION === $task) { + $url .= '/v1/chat/completions'; + } + + return $url; + } + + /** + * @param array $payload + * @param array $options + * + * @return array + */ + private function getPayload(array|string $payload, array $options): array + { + // Expect JSON input if string or not + if (\is_string($payload) || !(isset($payload['body']) || isset($payload['json']))) { + $payload = ['json' => ['inputs' => $payload]]; + + if (0 !== \count($options)) { + $payload['json']['parameters'] = $options; + } + } + + // Merge options into JSON payload + if (isset($payload['json'])) { + $payload['json'] = array_merge($payload['json'], $options); + } + + $payload['headers'] ??= ['Content-Type' => 'application/json']; + + return $payload; + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/Classification.php b/src/platform/src/Bridge/HuggingFace/Output/Classification.php new file mode 100644 index 000000000..48775e270 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/Classification.php @@ -0,0 +1,24 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class Classification +{ + public function __construct( + public string $label, + public float $score, + ) { + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/ClassificationResult.php b/src/platform/src/Bridge/HuggingFace/Output/ClassificationResult.php new file mode 100644 index 000000000..aa28ab6b6 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/ClassificationResult.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final class ClassificationResult +{ + /** + * @param Classification[] $classifications + */ + public function __construct( + public array $classifications, + ) { + } + + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + return new self( + array_map(fn (array $item) => new Classification($item['label'], $item['score']), $data) + ); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/DetectedObject.php b/src/platform/src/Bridge/HuggingFace/Output/DetectedObject.php new file mode 100644 index 000000000..e81b07881 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/DetectedObject.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class DetectedObject +{ + public function __construct( + public string $label, + public float $score, + public float $xmin, + public float $ymin, + public float $xmax, + public float $ymax, + ) { + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/FillMaskResult.php b/src/platform/src/Bridge/HuggingFace/Output/FillMaskResult.php new file mode 100644 index 000000000..5bbfbf4cd --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/FillMaskResult.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final class FillMaskResult +{ + /** + * @param MaskFill[] $fills + */ + public function __construct( + public array $fills, + ) { + } + + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + return new self(array_map( + fn (array $item) => new MaskFill( + $item['token'], + $item['token_str'], + $item['sequence'], + $item['score'], + ), + $data, + )); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/ImageSegment.php b/src/platform/src/Bridge/HuggingFace/Output/ImageSegment.php new file mode 100644 index 000000000..327942151 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/ImageSegment.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class ImageSegment +{ + public function __construct( + public string $label, + public ?float $score, + public string $mask, + ) { + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/ImageSegmentationResult.php b/src/platform/src/Bridge/HuggingFace/Output/ImageSegmentationResult.php new file mode 100644 index 000000000..999fab9d3 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/ImageSegmentationResult.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final class ImageSegmentationResult +{ + /** + * @param ImageSegment[] $segments + */ + public function __construct( + public array $segments, + ) { + } + + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + return new self( + array_map(fn (array $item) => new ImageSegment($item['label'], $item['score'], $item['mask']), $data) + ); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/MaskFill.php b/src/platform/src/Bridge/HuggingFace/Output/MaskFill.php new file mode 100644 index 000000000..242ead4ee --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/MaskFill.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class MaskFill +{ + public function __construct( + public int $token, + public string $tokenStr, + public string $sequence, + public float $score, + ) { + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/ObjectDetectionResult.php b/src/platform/src/Bridge/HuggingFace/Output/ObjectDetectionResult.php new file mode 100644 index 000000000..65c868162 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/ObjectDetectionResult.php @@ -0,0 +1,44 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final class ObjectDetectionResult +{ + /** + * @param DetectedObject[] $objects + */ + public function __construct( + public array $objects, + ) { + } + + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + return new self(array_map( + fn (array $item) => new DetectedObject( + $item['label'], + $item['score'], + $item['box']['xmin'], + $item['box']['ymin'], + $item['box']['xmax'], + $item['box']['ymax'], + ), + $data, + )); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/QuestionAnsweringResult.php b/src/platform/src/Bridge/HuggingFace/Output/QuestionAnsweringResult.php new file mode 100644 index 000000000..67015b4a4 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/QuestionAnsweringResult.php @@ -0,0 +1,39 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class QuestionAnsweringResult +{ + public function __construct( + public string $answer, + public int $startIndex, + public int $endIndex, + public float $score, + ) { + } + + /** + * @param array{answer: string, start: int, end: int, score: float} $data + */ + public static function fromArray(array $data): self + { + return new self( + $data['answer'], + $data['start'], + $data['end'], + $data['score'], + ); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/SentenceSimilarityResult.php b/src/platform/src/Bridge/HuggingFace/Output/SentenceSimilarityResult.php new file mode 100644 index 000000000..a0dea8cd9 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/SentenceSimilarityResult.php @@ -0,0 +1,34 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class SentenceSimilarityResult +{ + /** + * @param array $similarities + */ + public function __construct( + public array $similarities, + ) { + } + + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + return new self($data); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/TableQuestionAnsweringResult.php b/src/platform/src/Bridge/HuggingFace/Output/TableQuestionAnsweringResult.php new file mode 100644 index 000000000..ac5ad45bc --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/TableQuestionAnsweringResult.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class TableQuestionAnsweringResult +{ + /** + * @param array $cells + * @param array $aggregator + */ + public function __construct( + public string $answer, + public array $cells = [], + public array $aggregator = [], + ) { + } + + /** + * @param array{answer: string, cells?: array, aggregator?: array} $data + */ + public static function fromArray(array $data): self + { + return new self( + $data['answer'], + $data['cells'] ?? [], + $data['aggregator'] ?? [], + ); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/Token.php b/src/platform/src/Bridge/HuggingFace/Output/Token.php new file mode 100644 index 000000000..3d1dd5465 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/Token.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final readonly class Token +{ + public function __construct( + public string $entityGroup, + public float $score, + public string $word, + public int $start, + public int $end, + ) { + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/TokenClassificationResult.php b/src/platform/src/Bridge/HuggingFace/Output/TokenClassificationResult.php new file mode 100644 index 000000000..43dcc7c57 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/TokenClassificationResult.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final class TokenClassificationResult +{ + /** + * @param Token[] $tokens + */ + public function __construct( + public array $tokens, + ) { + } + + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + return new self(array_map( + fn (array $item) => new Token( + $item['entity_group'], + $item['score'], + $item['word'], + $item['start'], + $item['end'], + ), + $data, + )); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Output/ZeroShotClassificationResult.php b/src/platform/src/Bridge/HuggingFace/Output/ZeroShotClassificationResult.php new file mode 100644 index 000000000..1d40a5643 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Output/ZeroShotClassificationResult.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace\Output; + +/** + * @author Christopher Hertel + */ +final class ZeroShotClassificationResult +{ + /** + * @param array $labels + * @param array $scores + */ + public function __construct( + public array $labels, + public array $scores, + public ?string $sequence = null, + ) { + } + + /** + * @param array{labels: array, scores: array, sequence?: string} $data + */ + public static function fromArray(array $data): self + { + return new self( + $data['labels'], + $data['scores'], + $data['sequence'] ?? null, + ); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/PlatformFactory.php b/src/platform/src/Bridge/HuggingFace/PlatformFactory.php new file mode 100644 index 000000000..f25c51c17 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/PlatformFactory.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace; + +use Symfony\AI\Platform\Bridge\HuggingFace\Contract\FileNormalizer; +use Symfony\AI\Platform\Bridge\HuggingFace\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + string $provider = Provider::HF_INFERENCE, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + + return new Platform( + [new ModelClient($httpClient, $provider, $apiKey)], + [new ResponseConverter()], + Contract::create( + new FileNormalizer(), + new MessageBagNormalizer(), + ), + ); + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Provider.php b/src/platform/src/Bridge/HuggingFace/Provider.php new file mode 100644 index 000000000..2a24ea9fd --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Provider.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace; + +/** + * @author Christopher Hertel + */ +interface Provider +{ + public const CEREBRAS = 'cerebras'; + public const COHERE = 'cohere'; + public const FAL_AI = 'fal-ai'; + public const FIREWORKS = 'fireworks-ai'; + public const HYPERBOLIC = 'hyperbolic'; + public const HF_INFERENCE = 'hf-inference'; + public const NEBIUS = 'nebius'; + public const NOVITA = 'novita'; + public const REPLICATE = 'replicate'; + public const SAMBA_NOVA = 'sambanova'; + public const TOGETHER = 'together'; +} diff --git a/src/platform/src/Bridge/HuggingFace/ResponseConverter.php b/src/platform/src/Bridge/HuggingFace/ResponseConverter.php new file mode 100644 index 000000000..5607030d2 --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/ResponseConverter.php @@ -0,0 +1,96 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace; + +use Symfony\AI\Platform\Bridge\HuggingFace\Output\ClassificationResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\FillMaskResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\ImageSegmentationResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\ObjectDetectionResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\QuestionAnsweringResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\SentenceSimilarityResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\TableQuestionAnsweringResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\TokenClassificationResult; +use Symfony\AI\Platform\Bridge\HuggingFace\Output\ZeroShotClassificationResult; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\BinaryResponse; +use Symfony\AI\Platform\Response\ObjectResponse; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\ResponseConverterInterface as PlatformResponseConverter; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ResponseConverter implements PlatformResponseConverter +{ + public function supports(Model $model): bool + { + return true; + } + + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + if (503 === $response->getStatusCode()) { + return throw new RuntimeException('Service unavailable.'); + } + + if (404 === $response->getStatusCode()) { + return throw new InvalidArgumentException('Model, provider or task not found (404).'); + } + + $headers = $response->getHeaders(false); + $contentType = $headers['content-type'][0] ?? null; + $content = 'application/json' === $contentType ? $response->toArray(false) : $response->getContent(false); + + if (str_starts_with((string) $response->getStatusCode(), '4')) { + $message = \is_string($content) ? $content : + (\is_array($content['error']) ? $content['error'][0] : $content['error']); + + throw new InvalidArgumentException(\sprintf('API Client Error (%d): %s', $response->getStatusCode(), $message)); + } + + if (200 !== $response->getStatusCode()) { + throw new RuntimeException('Unhandled response code: '.$response->getStatusCode()); + } + + $task = $options['task'] ?? null; + + return match ($task) { + Task::AUDIO_CLASSIFICATION, Task::IMAGE_CLASSIFICATION => new ObjectResponse( + ClassificationResult::fromArray($content) + ), + Task::AUTOMATIC_SPEECH_RECOGNITION => new TextResponse($content['text'] ?? ''), + Task::CHAT_COMPLETION => new TextResponse($content['choices'][0]['message']['content'] ?? ''), + Task::FEATURE_EXTRACTION => new VectorResponse(new Vector($content)), + Task::TEXT_CLASSIFICATION => new ObjectResponse(ClassificationResult::fromArray(reset($content) ?? [])), + Task::FILL_MASK => new ObjectResponse(FillMaskResult::fromArray($content)), + Task::IMAGE_SEGMENTATION => new ObjectResponse(ImageSegmentationResult::fromArray($content)), + Task::IMAGE_TO_TEXT, Task::TEXT_GENERATION => new TextResponse($content[0]['generated_text'] ?? ''), + Task::TEXT_TO_IMAGE => new BinaryResponse($content, $contentType), + Task::OBJECT_DETECTION => new ObjectResponse(ObjectDetectionResult::fromArray($content)), + Task::QUESTION_ANSWERING => new ObjectResponse(QuestionAnsweringResult::fromArray($content)), + Task::SENTENCE_SIMILARITY => new ObjectResponse(SentenceSimilarityResult::fromArray($content)), + Task::SUMMARIZATION => new TextResponse($content[0]['summary_text']), + Task::TABLE_QUESTION_ANSWERING => new ObjectResponse(TableQuestionAnsweringResult::fromArray($content)), + Task::TOKEN_CLASSIFICATION => new ObjectResponse(TokenClassificationResult::fromArray($content)), + Task::TRANSLATION => new TextResponse($content[0]['translation_text'] ?? ''), + Task::ZERO_SHOT_CLASSIFICATION => new ObjectResponse(ZeroShotClassificationResult::fromArray($content)), + + default => throw new RuntimeException(\sprintf('Unsupported task: %s', $task)), + }; + } +} diff --git a/src/platform/src/Bridge/HuggingFace/Task.php b/src/platform/src/Bridge/HuggingFace/Task.php new file mode 100644 index 000000000..e2be0049c --- /dev/null +++ b/src/platform/src/Bridge/HuggingFace/Task.php @@ -0,0 +1,38 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\HuggingFace; + +/** + * @author Christopher Hertel + */ +interface Task +{ + public const AUDIO_CLASSIFICATION = 'audio-classification'; + public const AUTOMATIC_SPEECH_RECOGNITION = 'automatic-speech-recognition'; + public const CHAT_COMPLETION = 'chat-completion'; + public const FEATURE_EXTRACTION = 'feature-extraction'; + public const FILL_MASK = 'fill-mask'; + public const IMAGE_CLASSIFICATION = 'image-classification'; + public const IMAGE_SEGMENTATION = 'image-segmentation'; + public const IMAGE_TO_TEXT = 'image-to-text'; + public const OBJECT_DETECTION = 'object-detection'; + public const QUESTION_ANSWERING = 'question-answering'; + public const SENTENCE_SIMILARITY = 'sentence-similarity'; + public const SUMMARIZATION = 'summarization'; + public const TABLE_QUESTION_ANSWERING = 'table-question-answering'; + public const TEXT_CLASSIFICATION = 'text-classification'; + public const TEXT_GENERATION = 'text-generation'; + public const TEXT_TO_IMAGE = 'text-to-image'; + public const TOKEN_CLASSIFICATION = 'token-classification'; + public const TRANSLATION = 'translation'; + public const ZERO_SHOT_CLASSIFICATION = 'zero-shot-classification'; +} diff --git a/src/platform/src/Bridge/Meta/Contract/MessageBagNormalizer.php b/src/platform/src/Bridge/Meta/Contract/MessageBagNormalizer.php new file mode 100644 index 000000000..95681e6ad --- /dev/null +++ b/src/platform/src/Bridge/Meta/Contract/MessageBagNormalizer.php @@ -0,0 +1,51 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Meta\Contract; + +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Bridge\Meta\LlamaPromptConverter; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class MessageBagNormalizer extends ModelContractNormalizer +{ + public function __construct( + private readonly LlamaPromptConverter $promptConverter = new LlamaPromptConverter(), + ) { + } + + protected function supportedDataClass(): string + { + return MessageBagInterface::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Llama; + } + + /** + * @param MessageBagInterface $data + * + * @return array{prompt: string} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'prompt' => $this->promptConverter->convertToPrompt($data), + ]; + } +} diff --git a/src/platform/src/Bridge/Meta/Llama.php b/src/platform/src/Bridge/Meta/Llama.php new file mode 100644 index 000000000..d6d50e9e0 --- /dev/null +++ b/src/platform/src/Bridge/Meta/Llama.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Meta; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class Llama extends Model +{ + public const V3_3_70B_INSTRUCT = 'llama-3.3-70B-Instruct'; + public const V3_2_90B_VISION_INSTRUCT = 'llama-3.2-90b-vision-instruct'; + public const V3_2_11B_VISION_INSTRUCT = 'llama-3.2-11b-vision-instruct'; + public const V3_2_3B = 'llama-3.2-3b'; + public const V3_2_3B_INSTRUCT = 'llama-3.2-3b-instruct'; + public const V3_2_1B = 'llama-3.2-1b'; + public const V3_2_1B_INSTRUCT = 'llama-3.2-1b-instruct'; + public const V3_1_405B_INSTRUCT = 'llama-3.1-405b-instruct'; + public const V3_1_70B = 'llama-3.1-70b'; + public const V3_1_70B_INSTRUCT = 'llama-3-70b-instruct'; + public const V3_1_8B = 'llama-3.1-8b'; + public const V3_1_8B_INSTRUCT = 'llama-3.1-8b-instruct'; + public const V3_70B = 'llama-3-70b'; + public const V3_8B_INSTRUCT = 'llama-3-8b-instruct'; + public const V3_8B = 'llama-3-8b'; + + /** + * @param array $options + */ + public function __construct(string $name = self::V3_1_405B_INSTRUCT, array $options = []) + { + $capabilities = [ + Capability::INPUT_MESSAGES, + Capability::OUTPUT_TEXT, + ]; + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/Meta/LlamaPromptConverter.php b/src/platform/src/Bridge/Meta/LlamaPromptConverter.php new file mode 100644 index 000000000..c481c2590 --- /dev/null +++ b/src/platform/src/Bridge/Meta/LlamaPromptConverter.php @@ -0,0 +1,98 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Meta; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\UserMessage; + +/** + * @author Oskar Stark + */ +final class LlamaPromptConverter +{ + public function convertToPrompt(MessageBagInterface $messageBag): string + { + $messages = []; + + /** @var UserMessage|SystemMessage|AssistantMessage $message */ + foreach ($messageBag->getMessages() as $message) { + $messages[] = self::convertMessage($message); + } + + $messages = array_filter($messages, fn ($message) => '' !== $message); + + return trim(implode(\PHP_EOL.\PHP_EOL, $messages)).\PHP_EOL.\PHP_EOL.'<|start_header_id|>assistant<|end_header_id|>'; + } + + public function convertMessage(UserMessage|SystemMessage|AssistantMessage $message): string + { + if ($message instanceof SystemMessage) { + return trim(<<<|start_header_id|>system<|end_header_id|> + + {$message->content}<|eot_id|> + SYSTEM); + } + + if ($message instanceof AssistantMessage) { + if ('' === $message->content || null === $message->content) { + return ''; + } + + return trim(<<{$message->getRole()->value}<|end_header_id|> + + {$message->content}<|eot_id|> + ASSISTANT); + } + + // Handling of UserMessage + $count = \count($message->content); + + $contentParts = []; + if ($count > 1) { + foreach ($message->content as $value) { + if ($value instanceof Text) { + $contentParts[] = $value->text; + } + + if ($value instanceof ImageUrl) { + $contentParts[] = $value->url; + } + } + } elseif (1 === $count) { + $value = $message->content[0]; + if ($value instanceof Text) { + $contentParts[] = $value->text; + } + + if ($value instanceof ImageUrl) { + $contentParts[] = $value->url; + } + } else { + throw new RuntimeException('Unsupported message type.'); + } + + $content = implode(\PHP_EOL, $contentParts); + + return trim(<<{$message->getRole()->value}<|end_header_id|> + + {$content}<|eot_id|> + USER); + } +} diff --git a/src/platform/src/Bridge/Mistral/Contract/ToolNormalizer.php b/src/platform/src/Bridge/Mistral/Contract/ToolNormalizer.php new file mode 100644 index 000000000..d1b47d756 --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Contract/ToolNormalizer.php @@ -0,0 +1,29 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral\Contract; + +use Symfony\AI\Platform\Contract\Normalizer\ToolNormalizer as BaseToolNormalizer; + +/** + * @author Christopher Hertel + */ +class ToolNormalizer extends BaseToolNormalizer +{ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = parent::normalize($data, $format, $context); + + $array['function']['parameters'] ??= ['type' => 'object']; + + return $array; + } +} diff --git a/src/platform/src/Bridge/Mistral/Embeddings.php b/src/platform/src/Bridge/Mistral/Embeddings.php new file mode 100644 index 000000000..6f96f5f48 --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Embeddings.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class Embeddings extends Model +{ + public const MISTRAL_EMBED = 'mistral-embed'; + + /** + * @param array $options + */ + public function __construct( + string $name = self::MISTRAL_EMBED, + array $options = [], + ) { + parent::__construct($name, [Capability::INPUT_MULTIPLE], $options); + } +} diff --git a/src/platform/src/Bridge/Mistral/Embeddings/ModelClient.php b/src/platform/src/Bridge/Mistral/Embeddings/ModelClient.php new file mode 100644 index 000000000..ec6b06baf --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Embeddings/ModelClient.php @@ -0,0 +1,54 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral\Embeddings; + +use Symfony\AI\Platform\Bridge\Mistral\Embeddings; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(Model $model): bool + { + return $model instanceof Embeddings; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://api.mistral.ai/v1/embeddings', [ + 'auth_bearer' => $this->apiKey, + 'headers' => [ + 'Content-Type' => 'application/json', + ], + 'json' => array_merge($options, [ + 'model' => $model->getName(), + 'input' => $payload, + ]), + ]); + } +} diff --git a/src/platform/src/Bridge/Mistral/Embeddings/ResponseConverter.php b/src/platform/src/Bridge/Mistral/Embeddings/ResponseConverter.php new file mode 100644 index 000000000..cb934d7de --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Embeddings/ResponseConverter.php @@ -0,0 +1,51 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral\Embeddings; + +use Symfony\AI\Platform\Bridge\Mistral\Embeddings; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ResponseConverter implements ResponseConverterInterface +{ + public function supports(Model $model): bool + { + return $model instanceof Embeddings; + } + + public function convert(ResponseInterface $response, array $options = []): VectorResponse + { + $data = $response->toArray(false); + + if (200 !== $response->getStatusCode()) { + throw new RuntimeException(\sprintf('Unexpected response code %d: %s', $response->getStatusCode(), $response->getContent(false))); + } + + if (!isset($data['data'])) { + throw new RuntimeException('Response does not contain data'); + } + + return new VectorResponse( + ...array_map( + static fn (array $item): Vector => new Vector($item['embedding']), + $data['data'] + ), + ); + } +} diff --git a/src/platform/src/Bridge/Mistral/Llm/ModelClient.php b/src/platform/src/Bridge/Mistral/Llm/ModelClient.php new file mode 100644 index 000000000..93924cdc8 --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Llm/ModelClient.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral\Llm; + +use Symfony\AI\Platform\Bridge\Mistral\Mistral; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(Model $model): bool + { + return $model instanceof Mistral; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://api.mistral.ai/v1/chat/completions', [ + 'auth_bearer' => $this->apiKey, + 'headers' => [ + 'Content-Type' => 'application/json', + 'Accept' => 'application/json', + ], + 'json' => array_merge($options, $payload), + ]); + } +} diff --git a/src/platform/src/Bridge/Mistral/Llm/ResponseConverter.php b/src/platform/src/Bridge/Mistral/Llm/ResponseConverter.php new file mode 100644 index 000000000..7ca164b73 --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Llm/ResponseConverter.php @@ -0,0 +1,199 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral\Llm; + +use Symfony\AI\Platform\Bridge\Mistral\Mistral; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ChoiceResponse; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\StreamResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Component\HttpClient\Chunk\ServerSentEvent; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Component\HttpClient\Exception\JsonException; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + */ +final readonly class ResponseConverter implements ResponseConverterInterface +{ + public function supports(Model $model): bool + { + return $model instanceof Mistral; + } + + /** + * @param array $options + */ + public function convert(HttpResponse $response, array $options = []): LlmResponse + { + if ($options['stream'] ?? false) { + return new StreamResponse($this->convertStream($response)); + } + + $code = $response->getStatusCode(); + $data = $response->toArray(false); + + if (200 !== $code) { + throw new RuntimeException(\sprintf('Unexpected response code %d: %s', $code, $response->getContent(false))); + } + + if (!isset($data['choices'])) { + throw new RuntimeException('Response does not contain choices'); + } + + /** @var Choice[] $choices */ + $choices = array_map($this->convertChoice(...), $data['choices']); + + if (1 !== \count($choices)) { + return new ChoiceResponse(...$choices); + } + + if ($choices[0]->hasToolCall()) { + return new ToolCallResponse(...$choices[0]->getToolCalls()); + } + + return new TextResponse($choices[0]->getContent()); + } + + private function convertStream(HttpResponse $response): \Generator + { + $toolCalls = []; + foreach ((new EventSourceHttpClient())->stream($response) as $chunk) { + if (!$chunk instanceof ServerSentEvent || '[DONE]' === $chunk->getData()) { + continue; + } + + try { + $data = $chunk->getArrayData(); + } catch (JsonException) { + // try catch only needed for Symfony 6.4 + continue; + } + + if ($this->streamIsToolCall($data)) { + $toolCalls = $this->convertStreamToToolCalls($toolCalls, $data); + } + + if ([] !== $toolCalls && $this->isToolCallsStreamFinished($data)) { + yield new ToolCallResponse(...array_map($this->convertToolCall(...), $toolCalls)); + } + + if (!isset($data['choices'][0]['delta']['content'])) { + continue; + } + + yield $data['choices'][0]['delta']['content']; + } + } + + /** + * @param array $toolCalls + * @param array $data + * + * @return array + */ + private function convertStreamToToolCalls(array $toolCalls, array $data): array + { + if (!isset($data['choices'][0]['delta']['tool_calls'])) { + return $toolCalls; + } + + foreach ($data['choices'][0]['delta']['tool_calls'] as $i => $toolCall) { + if (isset($toolCall['id'])) { + // initialize tool call + $toolCalls[$i] = [ + 'id' => $toolCall['id'], + 'function' => $toolCall['function'], + ]; + continue; + } + + // add arguments delta to tool call + $toolCalls[$i]['function']['arguments'] .= $toolCall['function']['arguments']; + } + + return $toolCalls; + } + + /** + * @param array $data + */ + private function streamIsToolCall(array $data): bool + { + return isset($data['choices'][0]['delta']['tool_calls']); + } + + /** + * @param array $data + */ + private function isToolCallsStreamFinished(array $data): bool + { + return isset($data['choices'][0]['finish_reason']) && 'tool_calls' === $data['choices'][0]['finish_reason']; + } + + /** + * @param array{ + * index: integer, + * message: array{ + * role: 'assistant', + * content: ?string, + * tool_calls: array{ + * id: string, + * type: 'function', + * function: array{ + * name: string, + * arguments: string + * }, + * }, + * refusal: ?mixed + * }, + * logprobs: string, + * finish_reason: 'stop'|'length'|'tool_calls'|'content_filter', + * } $choice + */ + private function convertChoice(array $choice): Choice + { + if ('tool_calls' === $choice['finish_reason']) { + return new Choice(toolCalls: array_map([$this, 'convertToolCall'], $choice['message']['tool_calls'])); + } + + if ('stop' === $choice['finish_reason']) { + return new Choice($choice['message']['content']); + } + + throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finish_reason'])); + } + + /** + * @param array{ + * id: string, + * type: 'function', + * function: array{ + * name: string, + * arguments: string + * } + * } $toolCall + */ + private function convertToolCall(array $toolCall): ToolCall + { + $arguments = json_decode((string) $toolCall['function']['arguments'], true, \JSON_THROW_ON_ERROR); + + return new ToolCall($toolCall['id'], $toolCall['function']['name'], $arguments); + } +} diff --git a/src/platform/src/Bridge/Mistral/Mistral.php b/src/platform/src/Bridge/Mistral/Mistral.php new file mode 100644 index 000000000..62fec67e4 --- /dev/null +++ b/src/platform/src/Bridge/Mistral/Mistral.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class Mistral extends Model +{ + public const CODESTRAL = 'codestral-latest'; + public const CODESTRAL_MAMBA = 'open-codestral-mamba'; + public const MISTRAL_LARGE = 'mistral-large-latest'; + public const MISTRAL_SMALL = 'mistral-small-latest'; + public const MISTRAL_NEMO = 'open-mistral-nemo'; + public const MISTRAL_SABA = 'mistral-saba-latest'; + public const MINISTRAL_3B = 'mistral-3b-latest'; + public const MINISTRAL_8B = 'mistral-8b-latest'; + public const PIXSTRAL_LARGE = 'pixstral-large-latest'; + public const PIXSTRAL = 'pixstral-12b-latest'; + + /** + * @param array $options + */ + public function __construct( + string $name = self::MISTRAL_LARGE, + array $options = [], + ) { + $capabilities = [ + Capability::INPUT_MESSAGES, + Capability::OUTPUT_TEXT, + Capability::OUTPUT_STREAMING, + Capability::OUTPUT_STRUCTURED, + ]; + + if (\in_array($name, [self::PIXSTRAL, self::PIXSTRAL_LARGE, self::MISTRAL_SMALL], true)) { + $capabilities[] = Capability::INPUT_IMAGE; + } + + if (\in_array($name, [ + self::CODESTRAL, + self::MISTRAL_LARGE, + self::MISTRAL_SMALL, + self::MISTRAL_NEMO, + self::MINISTRAL_3B, + self::MINISTRAL_8B, + self::PIXSTRAL, + self::PIXSTRAL_LARGE, + ], true)) { + $capabilities[] = Capability::TOOL_CALLING; + } + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/Mistral/PlatformFactory.php b/src/platform/src/Bridge/Mistral/PlatformFactory.php new file mode 100644 index 000000000..0b5ccabbf --- /dev/null +++ b/src/platform/src/Bridge/Mistral/PlatformFactory.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Mistral; + +use Symfony\AI\Platform\Bridge\Mistral\Contract\ToolNormalizer; +use Symfony\AI\Platform\Bridge\Mistral\Embeddings\ModelClient as EmbeddingsModelClient; +use Symfony\AI\Platform\Bridge\Mistral\Embeddings\ResponseConverter as EmbeddingsResponseConverter; +use Symfony\AI\Platform\Bridge\Mistral\Llm\ModelClient as MistralModelClient; +use Symfony\AI\Platform\Bridge\Mistral\Llm\ResponseConverter as MistralResponseConverter; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + + return new Platform( + [new EmbeddingsModelClient($httpClient, $apiKey), new MistralModelClient($httpClient, $apiKey)], + [new EmbeddingsResponseConverter(), new MistralResponseConverter()], + Contract::create(new ToolNormalizer()), + ); + } +} diff --git a/src/platform/src/Bridge/Ollama/LlamaModelHandler.php b/src/platform/src/Bridge/Ollama/LlamaModelHandler.php new file mode 100644 index 000000000..56b0d05c9 --- /dev/null +++ b/src/platform/src/Bridge/Ollama/LlamaModelHandler.php @@ -0,0 +1,65 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Ollama; + +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class LlamaModelHandler implements ModelClientInterface, ResponseConverterInterface +{ + public function __construct( + private HttpClientInterface $httpClient, + private string $hostUrl, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Llama; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + // Revert Ollama's default streaming behavior + $options['stream'] ??= false; + + return $this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [ + 'headers' => ['Content-Type' => 'application/json'], + 'json' => array_merge($options, $payload), + ]); + } + + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + $data = $response->toArray(); + + if (!isset($data['message'])) { + throw new RuntimeException('Response does not contain message'); + } + + if (!isset($data['message']['content'])) { + throw new RuntimeException('Message does not contain content'); + } + + return new TextResponse($data['message']['content']); + } +} diff --git a/src/platform/src/Bridge/Ollama/PlatformFactory.php b/src/platform/src/Bridge/Ollama/PlatformFactory.php new file mode 100644 index 000000000..fdde43e38 --- /dev/null +++ b/src/platform/src/Bridge/Ollama/PlatformFactory.php @@ -0,0 +1,32 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Ollama; + +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final class PlatformFactory +{ + public static function create( + string $hostUrl = 'http://localhost:11434', + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + $handler = new LlamaModelHandler($httpClient, $hostUrl); + + return new Platform([$handler], [$handler]); + } +} diff --git a/src/platform/src/Bridge/OpenAI/DallE.php b/src/platform/src/Bridge/OpenAI/DallE.php new file mode 100644 index 000000000..b58b3a1b6 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/DallE.php @@ -0,0 +1,35 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Denis Zunke + */ +class DallE extends Model +{ + public const DALL_E_2 = 'dall-e-2'; + public const DALL_E_3 = 'dall-e-3'; + + /** @param array $options The default options for the model usage */ + public function __construct(string $name = self::DALL_E_2, array $options = []) + { + $capabilities = [ + Capability::INPUT_TEXT, + Capability::OUTPUT_IMAGE, + ]; + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/OpenAI/DallE/Base64Image.php b/src/platform/src/Bridge/OpenAI/DallE/Base64Image.php new file mode 100644 index 000000000..18d262853 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/DallE/Base64Image.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\DallE; + +use Webmozart\Assert\Assert; + +/** + * @author Denis Zunke + */ +final readonly class Base64Image +{ + public function __construct( + public string $encodedImage, + ) { + Assert::stringNotEmpty($encodedImage, 'The base64 encoded image generated must be given.'); + } +} diff --git a/src/platform/src/Bridge/OpenAI/DallE/ImageResponse.php b/src/platform/src/Bridge/OpenAI/DallE/ImageResponse.php new file mode 100644 index 000000000..8a7f16046 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/DallE/ImageResponse.php @@ -0,0 +1,38 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\DallE; + +use Symfony\AI\Platform\Response\BaseResponse; + +/** + * @author Denis Zunke + */ +class ImageResponse extends BaseResponse +{ + /** @var list */ + private readonly array $images; + + public function __construct( + public ?string $revisedPrompt = null, // Only string on Dall-E 3 usage + Base64Image|UrlImage ...$images, + ) { + $this->images = array_values($images); + } + + /** + * @return list + */ + public function getContent(): array + { + return $this->images; + } +} diff --git a/src/platform/src/Bridge/OpenAI/DallE/ModelClient.php b/src/platform/src/Bridge/OpenAI/DallE/ModelClient.php new file mode 100644 index 000000000..91abadcdb --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/DallE/ModelClient.php @@ -0,0 +1,76 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\DallE; + +use Symfony\AI\Platform\Bridge\OpenAI\DallE; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\ResponseConverterInterface as PlatformResponseConverter; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; +use Webmozart\Assert\Assert; + +/** + * @see https://platform.openai.com/docs/api-reference/images/create + * + * @author Denis Zunke + */ +final readonly class ModelClient implements PlatformResponseFactory, PlatformResponseConverter +{ + public function __construct( + private HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + ) { + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); + } + + public function supports(Model $model): bool + { + return $model instanceof DallE; + } + + public function request(Model $model, array|string $payload, array $options = []): HttpResponse + { + return $this->httpClient->request('POST', 'https://api.openai.com/v1/images/generations', [ + 'auth_bearer' => $this->apiKey, + 'json' => array_merge($options, [ + 'model' => $model->getName(), + 'prompt' => $payload, + ]), + ]); + } + + public function convert(HttpResponse $response, array $options = []): LlmResponse + { + $response = $response->toArray(); + if (!isset($response['data'][0])) { + throw new RuntimeException('No image generated.'); + } + + $images = []; + foreach ($response['data'] as $image) { + if ('url' === $options['response_format']) { + $images[] = new UrlImage($image['url']); + + continue; + } + + $images[] = new Base64Image($image['b64_json']); + } + + return new ImageResponse($image['revised_prompt'] ?? null, ...$images); + } +} diff --git a/src/platform/src/Bridge/OpenAI/DallE/UrlImage.php b/src/platform/src/Bridge/OpenAI/DallE/UrlImage.php new file mode 100644 index 000000000..016c13604 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/DallE/UrlImage.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\DallE; + +use Webmozart\Assert\Assert; + +/** + * @author Denis Zunke + */ +final readonly class UrlImage +{ + public function __construct( + public string $url, + ) { + Assert::stringNotEmpty($url, 'The image url must be given.'); + } +} diff --git a/src/platform/src/Bridge/OpenAI/Embeddings.php b/src/platform/src/Bridge/OpenAI/Embeddings.php new file mode 100644 index 000000000..907aa897e --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Embeddings.php @@ -0,0 +1,32 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI; + +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class Embeddings extends Model +{ + public const TEXT_ADA_002 = 'text-embedding-ada-002'; + public const TEXT_3_LARGE = 'text-embedding-3-large'; + public const TEXT_3_SMALL = 'text-embedding-3-small'; + + /** + * @param array $options + */ + public function __construct(string $name = self::TEXT_3_SMALL, array $options = []) + { + parent::__construct($name, [], $options); + } +} diff --git a/src/platform/src/Bridge/OpenAI/Embeddings/ModelClient.php b/src/platform/src/Bridge/OpenAI/Embeddings/ModelClient.php new file mode 100644 index 000000000..72866920b --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Embeddings/ModelClient.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\Embeddings; + +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements PlatformResponseFactory +{ + public function __construct( + private HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + ) { + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); + } + + public function supports(Model $model): bool + { + return $model instanceof Embeddings; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://api.openai.com/v1/embeddings', [ + 'auth_bearer' => $this->apiKey, + 'json' => array_merge($options, [ + 'model' => $model->getName(), + 'input' => $payload, + ]), + ]); + } +} diff --git a/src/platform/src/Bridge/OpenAI/Embeddings/ResponseConverter.php b/src/platform/src/Bridge/OpenAI/Embeddings/ResponseConverter.php new file mode 100644 index 000000000..caf471056 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Embeddings/ResponseConverter.php @@ -0,0 +1,47 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\Embeddings; + +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\ResponseConverterInterface as PlatformResponseConverter; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final class ResponseConverter implements PlatformResponseConverter +{ + public function supports(Model $model): bool + { + return $model instanceof Embeddings; + } + + public function convert(ResponseInterface $response, array $options = []): VectorResponse + { + $data = $response->toArray(); + + if (!isset($data['data'])) { + throw new RuntimeException('Response does not contain data'); + } + + return new VectorResponse( + ...array_map( + static fn (array $item): Vector => new Vector($item['embedding']), + $data['data'] + ), + ); + } +} diff --git a/src/platform/src/Bridge/OpenAI/GPT.php b/src/platform/src/Bridge/OpenAI/GPT.php new file mode 100644 index 000000000..1e36bd765 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/GPT.php @@ -0,0 +1,90 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + * @author Oskar Stark + */ +class GPT extends Model +{ + public const GPT_35_TURBO = 'gpt-3.5-turbo'; + public const GPT_35_TURBO_INSTRUCT = 'gpt-3.5-turbo-instruct'; + public const GPT_4 = 'gpt-4'; + public const GPT_4_TURBO = 'gpt-4-turbo'; + public const GPT_4O = 'gpt-4o'; + public const GPT_4O_MINI = 'gpt-4o-mini'; + public const GPT_4O_AUDIO = 'gpt-4o-audio-preview'; + public const O1_MINI = 'o1-mini'; + public const O1_PREVIEW = 'o1-preview'; + public const O3_MINI = 'o3-mini'; + public const O3_MINI_HIGH = 'o3-mini-high'; + public const GPT_45_PREVIEW = 'gpt-4.5-preview'; + public const GPT_41 = 'gpt-4.1'; + public const GPT_41_MINI = 'gpt-4.1-mini'; + public const GPT_41_NANO = 'gpt-4.1-nano'; + + private const IMAGE_SUPPORTING = [ + self::GPT_4_TURBO, + self::GPT_4O, + self::GPT_4O_MINI, + self::O1_MINI, + self::O1_PREVIEW, + self::O3_MINI, + self::GPT_45_PREVIEW, + self::GPT_41, + self::GPT_41_MINI, + self::GPT_41_NANO, + ]; + + private const STRUCTURED_OUTPUT_SUPPORTING = [ + self::GPT_4O, + self::GPT_4O_MINI, + self::O3_MINI, + self::GPT_45_PREVIEW, + self::GPT_41, + self::GPT_41_MINI, + self::GPT_41_NANO, + ]; + + /** + * @param array $options The default options for the model usage + */ + public function __construct( + string $name = self::GPT_4O, + array $options = ['temperature' => 1.0], + ) { + $capabilities = [ + Capability::INPUT_MESSAGES, + Capability::OUTPUT_TEXT, + Capability::OUTPUT_STREAMING, + Capability::TOOL_CALLING, + ]; + + if (self::GPT_4O_AUDIO === $name) { + $capabilities[] = Capability::INPUT_AUDIO; + } + + if (\in_array($name, self::IMAGE_SUPPORTING, true)) { + $capabilities[] = Capability::INPUT_IMAGE; + } + + if (\in_array($name, self::STRUCTURED_OUTPUT_SUPPORTING, true)) { + $capabilities[] = Capability::OUTPUT_STRUCTURED; + } + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/OpenAI/GPT/ModelClient.php b/src/platform/src/Bridge/OpenAI/GPT/ModelClient.php new file mode 100644 index 000000000..ef2c6ffea --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/GPT/ModelClient.php @@ -0,0 +1,51 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\GPT; + +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface as PlatformResponseFactory; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements PlatformResponseFactory +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); + } + + public function supports(Model $model): bool + { + return $model instanceof GPT; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://api.openai.com/v1/chat/completions', [ + 'auth_bearer' => $this->apiKey, + 'json' => array_merge($options, $payload), + ]); + } +} diff --git a/src/platform/src/Bridge/OpenAI/GPT/ResponseConverter.php b/src/platform/src/Bridge/OpenAI/GPT/ResponseConverter.php new file mode 100644 index 000000000..ca43dfb35 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/GPT/ResponseConverter.php @@ -0,0 +1,204 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\GPT; + +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Exception\ContentFilterException; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ChoiceResponse; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\StreamResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; +use Symfony\AI\Platform\ResponseConverterInterface as PlatformResponseConverter; +use Symfony\Component\HttpClient\Chunk\ServerSentEvent; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Component\HttpClient\Exception\JsonException; +use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + * @author Denis Zunke + */ +final class ResponseConverter implements PlatformResponseConverter +{ + public function supports(Model $model): bool + { + return $model instanceof GPT; + } + + public function convert(HttpResponse $response, array $options = []): LlmResponse + { + if ($options['stream'] ?? false) { + return new StreamResponse($this->convertStream($response)); + } + + try { + $data = $response->toArray(); + } catch (ClientExceptionInterface $e) { + $data = $response->toArray(throw: false); + + if (isset($data['error']['code']) && 'content_filter' === $data['error']['code']) { + throw new ContentFilterException(message: $data['error']['message'], previous: $e); + } + + throw $e; + } + + if (!isset($data['choices'])) { + throw new RuntimeException('Response does not contain choices'); + } + + /** @var Choice[] $choices */ + $choices = array_map($this->convertChoice(...), $data['choices']); + + if (1 !== \count($choices)) { + return new ChoiceResponse(...$choices); + } + + if ($choices[0]->hasToolCall()) { + return new ToolCallResponse(...$choices[0]->getToolCalls()); + } + + return new TextResponse($choices[0]->getContent()); + } + + private function convertStream(HttpResponse $response): \Generator + { + $toolCalls = []; + foreach ((new EventSourceHttpClient())->stream($response) as $chunk) { + if (!$chunk instanceof ServerSentEvent || '[DONE]' === $chunk->getData()) { + continue; + } + + try { + $data = $chunk->getArrayData(); + } catch (JsonException) { + // try catch only needed for Symfony 6.4 + continue; + } + + if ($this->streamIsToolCall($data)) { + $toolCalls = $this->convertStreamToToolCalls($toolCalls, $data); + } + + if ([] !== $toolCalls && $this->isToolCallsStreamFinished($data)) { + yield new ToolCallResponse(...array_map($this->convertToolCall(...), $toolCalls)); + } + + if (!isset($data['choices'][0]['delta']['content'])) { + continue; + } + + yield $data['choices'][0]['delta']['content']; + } + } + + /** + * @param array $toolCalls + * @param array $data + * + * @return array + */ + private function convertStreamToToolCalls(array $toolCalls, array $data): array + { + if (!isset($data['choices'][0]['delta']['tool_calls'])) { + return $toolCalls; + } + + foreach ($data['choices'][0]['delta']['tool_calls'] as $i => $toolCall) { + if (isset($toolCall['id'])) { + // initialize tool call + $toolCalls[$i] = [ + 'id' => $toolCall['id'], + 'function' => $toolCall['function'], + ]; + continue; + } + + // add arguments delta to tool call + $toolCalls[$i]['function']['arguments'] .= $toolCall['function']['arguments']; + } + + return $toolCalls; + } + + /** + * @param array $data + */ + private function streamIsToolCall(array $data): bool + { + return isset($data['choices'][0]['delta']['tool_calls']); + } + + /** + * @param array $data + */ + private function isToolCallsStreamFinished(array $data): bool + { + return isset($data['choices'][0]['finish_reason']) && 'tool_calls' === $data['choices'][0]['finish_reason']; + } + + /** + * @param array{ + * index: integer, + * message: array{ + * role: 'assistant', + * content: ?string, + * tool_calls: array{ + * id: string, + * type: 'function', + * function: array{ + * name: string, + * arguments: string + * }, + * }, + * refusal: ?mixed + * }, + * logprobs: string, + * finish_reason: 'stop'|'length'|'tool_calls'|'content_filter', + * } $choice + */ + private function convertChoice(array $choice): Choice + { + if ('tool_calls' === $choice['finish_reason']) { + return new Choice(toolCalls: array_map([$this, 'convertToolCall'], $choice['message']['tool_calls'])); + } + + if (\in_array($choice['finish_reason'], ['stop', 'length'], true)) { + return new Choice($choice['message']['content']); + } + + throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finish_reason'])); + } + + /** + * @param array{ + * id: string, + * type: 'function', + * function: array{ + * name: string, + * arguments: string + * } + * } $toolCall + */ + private function convertToolCall(array $toolCall): ToolCall + { + $arguments = json_decode($toolCall['function']['arguments'], true, \JSON_THROW_ON_ERROR); + + return new ToolCall($toolCall['id'], $toolCall['function']['name'], $arguments); + } +} diff --git a/src/platform/src/Bridge/OpenAI/PlatformFactory.php b/src/platform/src/Bridge/OpenAI/PlatformFactory.php new file mode 100644 index 000000000..1cea3e090 --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/PlatformFactory.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI; + +use Symfony\AI\Platform\Bridge\OpenAI\DallE\ModelClient as DallEModelClient; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings\ModelClient as EmbeddingsModelClient; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings\ResponseConverter as EmbeddingsResponseConverter; +use Symfony\AI\Platform\Bridge\OpenAI\GPT\ModelClient as GPTModelClient; +use Symfony\AI\Platform\Bridge\OpenAI\GPT\ResponseConverter as GPTResponseConverter; +use Symfony\AI\Platform\Bridge\OpenAI\Whisper\AudioNormalizer; +use Symfony\AI\Platform\Bridge\OpenAI\Whisper\ModelClient as WhisperModelClient; +use Symfony\AI\Platform\Bridge\OpenAI\Whisper\ResponseConverter as WhisperResponseConverter; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + + $dallEModelClient = new DallEModelClient($httpClient, $apiKey); + + return new Platform( + [ + new GPTModelClient($httpClient, $apiKey), + new EmbeddingsModelClient($httpClient, $apiKey), + $dallEModelClient, + new WhisperModelClient($httpClient, $apiKey), + ], + [ + new GPTResponseConverter(), + new EmbeddingsResponseConverter(), + $dallEModelClient, + new WhisperResponseConverter(), + ], + Contract::create(new AudioNormalizer()), + ); + } +} diff --git a/src/platform/src/Bridge/OpenAI/TokenOutputProcessor.php b/src/platform/src/Bridge/OpenAI/TokenOutputProcessor.php new file mode 100644 index 000000000..5df6da37b --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/TokenOutputProcessor.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI; + +use Symfony\AI\Agent\Output; +use Symfony\AI\Agent\OutputProcessorInterface; +use Symfony\AI\Platform\Response\StreamResponse; + +/** + * @author Denis Zunke + */ +final class TokenOutputProcessor implements OutputProcessorInterface +{ + public function processOutput(Output $output): void + { + if ($output->response instanceof StreamResponse) { + // Streams have to be handled manually as the tokens are part of the streamed chunks + return; + } + + $rawResponse = $output->response->getRawResponse(); + if (null === $rawResponse) { + return; + } + + $metadata = $output->response->getMetadata(); + + $metadata->add( + 'remaining_tokens', + (int) $rawResponse->getHeaders(false)['x-ratelimit-remaining-tokens'][0], + ); + + $content = $rawResponse->toArray(false); + + if (!\array_key_exists('usage', $content)) { + return; + } + + $metadata->add('prompt_tokens', $content['usage']['prompt_tokens'] ?? null); + $metadata->add('completion_tokens', $content['usage']['completion_tokens'] ?? null); + $metadata->add('total_tokens', $content['usage']['total_tokens'] ?? null); + } +} diff --git a/src/platform/src/Bridge/OpenAI/Whisper.php b/src/platform/src/Bridge/OpenAI/Whisper.php new file mode 100644 index 000000000..3a70bd04d --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Whisper.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class Whisper extends Model +{ + public const WHISPER_1 = 'whisper-1'; + + /** + * @param array $options + */ + public function __construct(string $name = self::WHISPER_1, array $options = []) + { + $capabilities = [ + Capability::INPUT_AUDIO, + Capability::OUTPUT_TEXT, + ]; + + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/OpenAI/Whisper/AudioNormalizer.php b/src/platform/src/Bridge/OpenAI/Whisper/AudioNormalizer.php new file mode 100644 index 000000000..d33d8bc8c --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Whisper/AudioNormalizer.php @@ -0,0 +1,48 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\Whisper; + +use Symfony\AI\Platform\Bridge\OpenAI\Whisper; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\Content\Audio; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class AudioNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof Audio && $context[Contract::CONTEXT_MODEL] instanceof Whisper; + } + + public function getSupportedTypes(?string $format): array + { + return [ + Audio::class => true, + ]; + } + + /** + * @param Audio $data + * + * @return array{model: string, file: resource} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'model' => $context[Contract::CONTEXT_MODEL]->getName(), + 'file' => $data->asResource(), + ]; + } +} diff --git a/src/platform/src/Bridge/OpenAI/Whisper/ModelClient.php b/src/platform/src/Bridge/OpenAI/Whisper/ModelClient.php new file mode 100644 index 000000000..fb6462e3b --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Whisper/ModelClient.php @@ -0,0 +1,47 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\Whisper; + +use Symfony\AI\Platform\Bridge\OpenAI\Whisper; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface as BaseModelClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class ModelClient implements BaseModelClient +{ + public function __construct( + private HttpClientInterface $httpClient, + #[\SensitiveParameter] + private string $apiKey, + ) { + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + } + + public function supports(Model $model): bool + { + return $model instanceof Whisper; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://api.openai.com/v1/audio/transcriptions', [ + 'auth_bearer' => $this->apiKey, + 'headers' => ['Content-Type' => 'multipart/form-data'], + 'body' => array_merge($options, $payload, ['model' => $model->getName()]), + ]); + } +} diff --git a/src/platform/src/Bridge/OpenAI/Whisper/ResponseConverter.php b/src/platform/src/Bridge/OpenAI/Whisper/ResponseConverter.php new file mode 100644 index 000000000..094421c2e --- /dev/null +++ b/src/platform/src/Bridge/OpenAI/Whisper/ResponseConverter.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenAI\Whisper; + +use Symfony\AI\Platform\Bridge\OpenAI\Whisper; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface as BaseResponseConverter; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + */ +final class ResponseConverter implements BaseResponseConverter +{ + public function supports(Model $model): bool + { + return $model instanceof Whisper; + } + + public function convert(HttpResponse $response, array $options = []): LlmResponse + { + $data = $response->toArray(); + + return new TextResponse($data['text']); + } +} diff --git a/src/platform/src/Bridge/OpenRouter/Client.php b/src/platform/src/Bridge/OpenRouter/Client.php new file mode 100644 index 000000000..65e5cd626 --- /dev/null +++ b/src/platform/src/Bridge/OpenRouter/Client.php @@ -0,0 +1,70 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenRouter; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author rglozman + */ +final readonly class Client implements ModelClientInterface, ResponseConverterInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] private string $apiKey, + ) { + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); + Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); + } + + public function supports(Model $model): bool + { + return true; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://openrouter.ai/api/v1/chat/completions', [ + 'auth_bearer' => $this->apiKey, + 'json' => array_merge($options, $payload), + ]); + } + + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + dump($response->getContent(false)); + + $data = $response->toArray(); + + if (!isset($data['choices'][0]['message'])) { + throw new RuntimeException('Response does not contain message'); + } + + if (!isset($data['choices'][0]['message']['content'])) { + throw new RuntimeException('Message does not contain content'); + } + + return new TextResponse($data['choices'][0]['message']['content']); + } +} diff --git a/src/platform/src/Bridge/OpenRouter/PlatformFactory.php b/src/platform/src/Bridge/OpenRouter/PlatformFactory.php new file mode 100644 index 000000000..15b53da27 --- /dev/null +++ b/src/platform/src/Bridge/OpenRouter/PlatformFactory.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\OpenRouter; + +use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author rglozman + */ +final class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + $handler = new Client($httpClient, $apiKey); + + return new Platform([$handler], [$handler], Contract::create( + new AssistantMessageNormalizer(), + new MessageBagNormalizer(), + new UserMessageNormalizer(), + )); + } +} diff --git a/src/platform/src/Bridge/Replicate/Client.php b/src/platform/src/Bridge/Replicate/Client.php new file mode 100644 index 000000000..f3483352b --- /dev/null +++ b/src/platform/src/Bridge/Replicate/Client.php @@ -0,0 +1,64 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Replicate; + +use Symfony\Component\Clock\ClockInterface; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class Client +{ + public function __construct( + private HttpClientInterface $httpClient, + private ClockInterface $clock, + #[\SensitiveParameter] private string $apiKey, + ) { + } + + /** + * @param string $model The model name on Replicate, e.g. "meta/meta-llama-3.1-405b-instruct" + * @param array $body + */ + public function request(string $model, string $endpoint, array $body): ResponseInterface + { + $url = \sprintf('https://api.replicate.com/v1/models/%s/%s', $model, $endpoint); + + $response = $this->httpClient->request('POST', $url, [ + 'headers' => ['Content-Type' => 'application/json'], + 'auth_bearer' => $this->apiKey, + 'json' => ['input' => $body], + ]); + $data = $response->toArray(); + + while (!\in_array($data['status'], ['succeeded', 'failed', 'canceled'], true)) { + $this->clock->sleep(1); // we need to wait until the prediction is ready + + $response = $this->getResponse($data['id']); + $data = $response->toArray(); + } + + return $response; + } + + private function getResponse(string $id): ResponseInterface + { + $url = \sprintf('https://api.replicate.com/v1/predictions/%s', $id); + + return $this->httpClient->request('GET', $url, [ + 'headers' => ['Content-Type' => 'application/json'], + 'auth_bearer' => $this->apiKey, + ]); + } +} diff --git a/src/platform/src/Bridge/Replicate/Contract/LlamaMessageBagNormalizer.php b/src/platform/src/Bridge/Replicate/Contract/LlamaMessageBagNormalizer.php new file mode 100644 index 000000000..f5e7376af --- /dev/null +++ b/src/platform/src/Bridge/Replicate/Contract/LlamaMessageBagNormalizer.php @@ -0,0 +1,53 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Replicate\Contract; + +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Bridge\Meta\LlamaPromptConverter; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +final class LlamaMessageBagNormalizer extends ModelContractNormalizer +{ + public function __construct( + private readonly LlamaPromptConverter $promptConverter = new LlamaPromptConverter(), + ) { + } + + protected function supportedDataClass(): string + { + return MessageBagInterface::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Llama; + } + + /** + * @param MessageBagInterface $data + * + * @return array{system: string, prompt: string} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'system' => $this->promptConverter->convertMessage($data->getSystemMessage() ?? new SystemMessage('')), + 'prompt' => $this->promptConverter->convertToPrompt($data->withoutSystemMessage()), + ]; + } +} diff --git a/src/platform/src/Bridge/Replicate/LlamaModelClient.php b/src/platform/src/Bridge/Replicate/LlamaModelClient.php new file mode 100644 index 000000000..18c36bef5 --- /dev/null +++ b/src/platform/src/Bridge/Replicate/LlamaModelClient.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Replicate; + +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class LlamaModelClient implements ModelClientInterface +{ + public function __construct( + private Client $client, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Llama; + } + + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface + { + Assert::isInstanceOf($model, Llama::class); + + return $this->client->request(\sprintf('meta/meta-%s', $model->getName()), 'predictions', $payload); + } +} diff --git a/src/platform/src/Bridge/Replicate/LlamaResponseConverter.php b/src/platform/src/Bridge/Replicate/LlamaResponseConverter.php new file mode 100644 index 000000000..65d1cf19e --- /dev/null +++ b/src/platform/src/Bridge/Replicate/LlamaResponseConverter.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Replicate; + +use Symfony\AI\Platform\Bridge\Meta\Llama; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + */ +final readonly class LlamaResponseConverter implements ResponseConverterInterface +{ + public function supports(Model $model): bool + { + return $model instanceof Llama; + } + + public function convert(HttpResponse $response, array $options = []): LlmResponse + { + $data = $response->toArray(); + + if (!isset($data['output'])) { + throw new RuntimeException('Response does not contain output'); + } + + return new TextResponse(implode('', $data['output'])); + } +} diff --git a/src/platform/src/Bridge/Replicate/PlatformFactory.php b/src/platform/src/Bridge/Replicate/PlatformFactory.php new file mode 100644 index 000000000..51e9d7a86 --- /dev/null +++ b/src/platform/src/Bridge/Replicate/PlatformFactory.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Replicate; + +use Symfony\AI\Platform\Bridge\Replicate\Contract\LlamaMessageBagNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Platform; +use Symfony\Component\Clock\Clock; +use Symfony\Component\HttpClient\HttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + return new Platform( + [new LlamaModelClient(new Client($httpClient ?? HttpClient::create(), new Clock(), $apiKey))], + [new LlamaResponseConverter()], + Contract::create(new LlamaMessageBagNormalizer()), + ); + } +} diff --git a/src/platform/src/Bridge/TransformersPHP/Platform.php b/src/platform/src/Bridge/TransformersPHP/Platform.php new file mode 100644 index 000000000..166afc9aa --- /dev/null +++ b/src/platform/src/Bridge/TransformersPHP/Platform.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\TransformersPHP; + +use Codewithkyrian\Transformers\Pipelines\Task; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Response\ObjectResponse; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\AI\Platform\Response\TextResponse; + +use function Codewithkyrian\Transformers\Pipelines\pipeline; + +/** + * @author Christopher Hertel + */ +final class Platform implements PlatformInterface +{ + public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface + { + if (null === $task = $options['task'] ?? null) { + throw new InvalidArgumentException('The task option is required.'); + } + + $pipeline = pipeline( + $options['task'], + $model->getName(), + $options['quantized'] ?? true, + $options['config'] ?? null, + $options['cacheDir'] ?? null, + $options['revision'] ?? 'main', + $options['modelFilename'] ?? null, + ); + + $data = $pipeline($input); + + return match ($task) { + Task::Text2TextGeneration => new TextResponse($data[0]['generated_text']), + default => new ObjectResponse($data), + }; + } +} diff --git a/src/platform/src/Bridge/TransformersPHP/PlatformFactory.php b/src/platform/src/Bridge/TransformersPHP/PlatformFactory.php new file mode 100644 index 000000000..8dbb789a9 --- /dev/null +++ b/src/platform/src/Bridge/TransformersPHP/PlatformFactory.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\TransformersPHP; + +use Codewithkyrian\Transformers\Transformers; +use Symfony\AI\Platform\Exception\RuntimeException; + +/** + * @author Christopher Hertel + */ +final readonly class PlatformFactory +{ + public static function create(): Platform + { + if (!class_exists(Transformers::class)) { + throw new RuntimeException('TransformersPHP is not installed. Please install it using "composer require codewithkyrian/transformers".'); + } + + return new Platform(); + } +} diff --git a/src/platform/src/Bridge/Voyage/ModelHandler.php b/src/platform/src/Bridge/Voyage/ModelHandler.php new file mode 100644 index 000000000..e38e92047 --- /dev/null +++ b/src/platform/src/Bridge/Voyage/ModelHandler.php @@ -0,0 +1,63 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Voyage; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +final readonly class ModelHandler implements ModelClientInterface, ResponseConverterInterface +{ + public function __construct( + private HttpClientInterface $httpClient, + #[\SensitiveParameter] private string $apiKey, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Voyage; + } + + public function request(Model $model, object|string|array $payload, array $options = []): ResponseInterface + { + return $this->httpClient->request('POST', 'https://api.voyageai.com/v1/embeddings', [ + 'auth_bearer' => $this->apiKey, + 'json' => [ + 'model' => $model->getName(), + 'input' => $payload, + ], + ]); + } + + public function convert(ResponseInterface $response, array $options = []): LlmResponse + { + $response = $response->toArray(); + + if (!isset($response['data'])) { + throw new RuntimeException('Response does not contain embedding data'); + } + + $vectors = array_map(fn (array $data) => new Vector($data['embedding']), $response['data']); + + return new VectorResponse($vectors[0]); + } +} diff --git a/src/platform/src/Bridge/Voyage/PlatformFactory.php b/src/platform/src/Bridge/Voyage/PlatformFactory.php new file mode 100644 index 000000000..8497a9c02 --- /dev/null +++ b/src/platform/src/Bridge/Voyage/PlatformFactory.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Voyage; + +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] + string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + $handler = new ModelHandler($httpClient, $apiKey); + + return new Platform([$handler], [$handler]); + } +} diff --git a/src/platform/src/Bridge/Voyage/Voyage.php b/src/platform/src/Bridge/Voyage/Voyage.php new file mode 100644 index 000000000..955748499 --- /dev/null +++ b/src/platform/src/Bridge/Voyage/Voyage.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Voyage; + +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +/** + * @author Christopher Hertel + */ +class Voyage extends Model +{ + public const V3 = 'voyage-3'; + public const V3_LITE = 'voyage-3-lite'; + public const FINANCE_2 = 'voyage-finance-2'; + public const MULTILINGUAL_2 = 'voyage-multilingual-2'; + public const LAW_2 = 'voyage-law-2'; + public const CODE_2 = 'voyage-code-2'; + + /** + * @param array $options + */ + public function __construct(string $name = self::V3, array $options = []) + { + parent::__construct($name, [Capability::INPUT_MULTIPLE], $options); + } +} diff --git a/src/platform/src/Capability.php b/src/platform/src/Capability.php new file mode 100644 index 000000000..78a7c1b02 --- /dev/null +++ b/src/platform/src/Capability.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +/** + * @author Christopher Hertel + */ +class Capability +{ + // INPUT + public const INPUT_AUDIO = 'input-audio'; + public const INPUT_IMAGE = 'input-image'; + public const INPUT_MESSAGES = 'input-messages'; + public const INPUT_MULTIPLE = 'input-multiple'; + public const INPUT_PDF = 'input-pdf'; + public const INPUT_TEXT = 'input-text'; + + // OUTPUT + public const OUTPUT_AUDIO = 'output-audio'; + public const OUTPUT_IMAGE = 'output-image'; + public const OUTPUT_STREAMING = 'output-streaming'; + public const OUTPUT_STRUCTURED = 'output-structured'; + public const OUTPUT_TEXT = 'output-text'; + + // FUNCTIONALITY + public const TOOL_CALLING = 'tool-calling'; +} diff --git a/src/platform/src/Contract.php b/src/platform/src/Contract.php new file mode 100644 index 000000000..76662f10e --- /dev/null +++ b/src/platform/src/Contract.php @@ -0,0 +1,86 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +use Symfony\AI\Platform\Contract\Normalizer\Message\AssistantMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\AudioNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\ImageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\ImageUrlNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\TextNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\MessageBagNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\SystemMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\UserMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Response\ToolCallNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\ToolNormalizer; +use Symfony\AI\Platform\Tool\Tool; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; +use Symfony\Component\Serializer\Serializer; + +/** + * @author Christopher Hertel + */ +final readonly class Contract +{ + public const CONTEXT_MODEL = 'model'; + + public function __construct( + private NormalizerInterface $normalizer, + ) { + } + + public static function create(NormalizerInterface ...$normalizer): self + { + // Messages + $normalizer[] = new MessageBagNormalizer(); + $normalizer[] = new AssistantMessageNormalizer(); + $normalizer[] = new SystemMessageNormalizer(); + $normalizer[] = new ToolCallMessageNormalizer(); + $normalizer[] = new UserMessageNormalizer(); + + // Message Content + $normalizer[] = new AudioNormalizer(); + $normalizer[] = new ImageNormalizer(); + $normalizer[] = new ImageUrlNormalizer(); + $normalizer[] = new TextNormalizer(); + + // Options + $normalizer[] = new ToolNormalizer(); + + // Response + $normalizer[] = new ToolCallNormalizer(); + + return new self( + new Serializer($normalizer), + ); + } + + /** + * @param object|array|string $input + * + * @return array|string + */ + public function createRequestPayload(Model $model, object|array|string $input): string|array + { + return $this->normalizer->normalize($input, context: [self::CONTEXT_MODEL => $model]); + } + + /** + * @param Tool[] $tools + * + * @return array + */ + public function createToolOption(array $tools, Model $model): array + { + return $this->normalizer->normalize($tools, context: [self::CONTEXT_MODEL => $model]); + } +} diff --git a/src/platform/src/Contract/JsonSchema/Attribute/With.php b/src/platform/src/Contract/JsonSchema/Attribute/With.php new file mode 100644 index 000000000..02f6cb76a --- /dev/null +++ b/src/platform/src/Contract/JsonSchema/Attribute/With.php @@ -0,0 +1,148 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\JsonSchema\Attribute; + +use Webmozart\Assert\Assert; + +/** + * @author Oskar Stark + */ +#[\Attribute(\Attribute::TARGET_PARAMETER)] +final readonly class With +{ + /** + * @param list|null $enum + * @param string|int|string[]|null $const + */ + public function __construct( + // can be used by many types + public ?array $enum = null, + public string|int|array|null $const = null, + + // string + public ?string $pattern = null, + public ?int $minLength = null, + public ?int $maxLength = null, + + // integer + public ?int $minimum = null, + public ?int $maximum = null, + public ?int $multipleOf = null, + public ?int $exclusiveMinimum = null, + public ?int $exclusiveMaximum = null, + + // array + public ?int $minItems = null, + public ?int $maxItems = null, + public ?bool $uniqueItems = null, + public ?int $minContains = null, + public ?int $maxContains = null, + + // object + public ?bool $required = null, + public ?int $minProperties = null, + public ?int $maxProperties = null, + public ?bool $dependentRequired = null, + ) { + if (\is_array($enum)) { + Assert::allString($enum); + } + + if (\is_string($const)) { + Assert::stringNotEmpty(trim($const)); + } + + if (\is_string($pattern)) { + Assert::stringNotEmpty(trim($pattern)); + } + + if (\is_int($minLength)) { + Assert::greaterThanEq($minLength, 0); + + if (\is_int($maxLength)) { + Assert::greaterThanEq($maxLength, $minLength); + } + } + + if (\is_int($maxLength)) { + Assert::greaterThanEq($maxLength, 0); + } + + if (\is_int($minimum)) { + Assert::greaterThanEq($minimum, 0); + + if (\is_int($maximum)) { + Assert::greaterThanEq($maximum, $minimum); + } + } + + if (\is_int($maximum)) { + Assert::greaterThanEq($maximum, 0); + } + + if (\is_int($multipleOf)) { + Assert::greaterThanEq($multipleOf, 0); + } + + if (\is_int($exclusiveMinimum)) { + Assert::greaterThanEq($exclusiveMinimum, 0); + + if (\is_int($exclusiveMaximum)) { + Assert::greaterThanEq($exclusiveMaximum, $exclusiveMinimum); + } + } + + if (\is_int($exclusiveMaximum)) { + Assert::greaterThanEq($exclusiveMaximum, 0); + } + + if (\is_int($minItems)) { + Assert::greaterThanEq($minItems, 0); + + if (\is_int($maxItems)) { + Assert::greaterThanEq($maxItems, $minItems); + } + } + + if (\is_int($maxItems)) { + Assert::greaterThanEq($maxItems, 0); + } + + if (\is_bool($uniqueItems)) { + Assert::true($uniqueItems); + } + + if (\is_int($minContains)) { + Assert::greaterThanEq($minContains, 0); + + if (\is_int($maxContains)) { + Assert::greaterThanEq($maxContains, $minContains); + } + } + + if (\is_int($maxContains)) { + Assert::greaterThanEq($maxContains, 0); + } + + if (\is_int($minProperties)) { + Assert::greaterThanEq($minProperties, 0); + + if (\is_int($maxProperties)) { + Assert::greaterThanEq($maxProperties, $minProperties); + } + } + + if (\is_int($maxProperties)) { + Assert::greaterThanEq($maxProperties, 0); + } + } +} diff --git a/src/platform/src/Contract/JsonSchema/DescriptionParser.php b/src/platform/src/Contract/JsonSchema/DescriptionParser.php new file mode 100644 index 000000000..cdb8a3afb --- /dev/null +++ b/src/platform/src/Contract/JsonSchema/DescriptionParser.php @@ -0,0 +1,59 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\JsonSchema; + +/** + * @author Christopher Hertel + */ +final readonly class DescriptionParser +{ + public function getDescription(\ReflectionProperty|\ReflectionParameter $reflector): string + { + if ($reflector instanceof \ReflectionProperty) { + return $this->fromProperty($reflector); + } + + return $this->fromParameter($reflector); + } + + private function fromProperty(\ReflectionProperty $property): string + { + $comment = $property->getDocComment(); + + if (\is_string($comment) && preg_match('/@var\s+[a-zA-Z\\\\]+\s+((.*)(?=\*)|.*)/', $comment, $matches)) { + return trim($matches[1]); + } + + $class = $property->getDeclaringClass(); + if ($class->hasMethod('__construct')) { + return $this->fromParameter( + new \ReflectionParameter([$class->getName(), '__construct'], $property->getName()) + ); + } + + return ''; + } + + private function fromParameter(\ReflectionParameter $parameter): string + { + $comment = $parameter->getDeclaringFunction()->getDocComment(); + if (!$comment) { + return ''; + } + + if (preg_match('/@param\s+\S+\s+\$'.preg_quote($parameter->getName(), '/').'\s+((.*)(?=\*)|.*)/', $comment, $matches)) { + return trim($matches[1]); + } + + return ''; + } +} diff --git a/src/platform/src/Contract/JsonSchema/Factory.php b/src/platform/src/Contract/JsonSchema/Factory.php new file mode 100644 index 000000000..cbfc66b92 --- /dev/null +++ b/src/platform/src/Contract/JsonSchema/Factory.php @@ -0,0 +1,185 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\JsonSchema; + +use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\Component\TypeInfo\Type; +use Symfony\Component\TypeInfo\Type\BuiltinType; +use Symfony\Component\TypeInfo\Type\CollectionType; +use Symfony\Component\TypeInfo\Type\ObjectType; +use Symfony\Component\TypeInfo\TypeIdentifier; +use Symfony\Component\TypeInfo\TypeResolver\TypeResolver; + +/** + * @phpstan-type JsonSchema array{ + * type: 'object', + * properties: array, + * const?: string|int|list, + * pattern?: string, + * minLength?: int, + * maxLength?: int, + * minimum?: int, + * maximum?: int, + * multipleOf?: int, + * exclusiveMinimum?: int, + * exclusiveMaximum?: int, + * minItems?: int, + * maxItems?: int, + * uniqueItems?: bool, + * minContains?: int, + * maxContains?: int, + * required?: bool, + * minProperties?: int, + * maxProperties?: int, + * dependentRequired?: bool, + * }>, + * required: list, + * additionalProperties: false, + * } + * + * @author Christopher Hertel + */ +final readonly class Factory +{ + private TypeResolver $typeResolver; + + public function __construct( + private DescriptionParser $descriptionParser = new DescriptionParser(), + ?TypeResolver $typeResolver = null, + ) { + $this->typeResolver = $typeResolver ?? TypeResolver::create(); + } + + /** + * @return JsonSchema|null + */ + public function buildParameters(string $className, string $methodName): ?array + { + $reflection = new \ReflectionMethod($className, $methodName); + + return $this->convertTypes($reflection->getParameters()); + } + + /** + * @return JsonSchema|null + */ + public function buildProperties(string $className): ?array + { + $reflection = new \ReflectionClass($className); + + return $this->convertTypes($reflection->getProperties()); + } + + /** + * @param list<\ReflectionProperty|\ReflectionParameter> $elements + * + * @return JsonSchema|null + */ + private function convertTypes(array $elements): ?array + { + if (0 === \count($elements)) { + return null; + } + + $result = [ + 'type' => 'object', + 'properties' => [], + 'required' => [], + 'additionalProperties' => false, + ]; + + foreach ($elements as $element) { + $name = $element->getName(); + $type = $this->typeResolver->resolve($element); + $schema = $this->getTypeSchema($type); + + if ($type->isNullable()) { + $schema['type'] = [$schema['type'], 'null']; + } elseif (!($element instanceof \ReflectionParameter && $element->isOptional())) { + $result['required'][] = $name; + } + + $description = $this->descriptionParser->getDescription($element); + if ('' !== $description) { + $schema['description'] = $description; + } + + // Check for ToolParameter attributes + $attributes = $element->getAttributes(With::class); + if (\count($attributes) > 0) { + $attributeState = array_filter((array) $attributes[0]->newInstance(), fn ($value) => null !== $value); + $schema = array_merge($schema, $attributeState); + } + + $result['properties'][$name] = $schema; + } + + return $result; + } + + /** + * @return array + */ + private function getTypeSchema(Type $type): array + { + switch (true) { + case $type->isIdentifiedBy(TypeIdentifier::INT): + return ['type' => 'integer']; + + case $type->isIdentifiedBy(TypeIdentifier::FLOAT): + return ['type' => 'number']; + + case $type->isIdentifiedBy(TypeIdentifier::BOOL): + return ['type' => 'boolean']; + + case $type->isIdentifiedBy(TypeIdentifier::ARRAY): + \assert($type instanceof CollectionType); + $collectionValueType = $type->getCollectionValueType(); + + if ($collectionValueType->isIdentifiedBy(TypeIdentifier::OBJECT)) { + \assert($collectionValueType instanceof ObjectType); + + return [ + 'type' => 'array', + 'items' => $this->buildProperties($collectionValueType->getClassName()), + ]; + } + + return [ + 'type' => 'array', + 'items' => $this->getTypeSchema($collectionValueType), + ]; + + case $type->isIdentifiedBy(TypeIdentifier::OBJECT): + if ($type instanceof BuiltinType) { + throw new InvalidArgumentException('Cannot build schema from plain object type.'); + } + \assert($type instanceof ObjectType); + if (\in_array($type->getClassName(), ['DateTime', 'DateTimeImmutable', 'DateTimeInterface'], true)) { + return ['type' => 'string', 'format' => 'date-time']; + } else { + // Recursively build the schema for an object type + return $this->buildProperties($type->getClassName()) ?? ['type' => 'object']; + } + + // no break + case $type->isIdentifiedBy(TypeIdentifier::STRING): + default: + // Fallback to string for any unhandled types + return ['type' => 'string']; + } + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/AssistantMessageNormalizer.php b/src/platform/src/Contract/Normalizer/Message/AssistantMessageNormalizer.php new file mode 100644 index 000000000..c33703b41 --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/AssistantMessageNormalizer.php @@ -0,0 +1,59 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message; + +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class AssistantMessageNormalizer implements NormalizerInterface, NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof AssistantMessage; + } + + public function getSupportedTypes(?string $format): array + { + return [ + AssistantMessage::class => true, + ]; + } + + /** + * @param AssistantMessage $data + * + * @return array{role: 'assistant', content: string} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = [ + 'role' => $data->getRole()->value, + ]; + + if (null !== $data->content) { + $array['content'] = $data->content; + } + + if ($data->hasToolCalls()) { + $array['tool_calls'] = $this->normalizer->normalize($data->toolCalls, $format, $context); + } + + return $array; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/Content/AudioNormalizer.php b/src/platform/src/Contract/Normalizer/Message/Content/AudioNormalizer.php new file mode 100644 index 000000000..3fcf220ec --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/Content/AudioNormalizer.php @@ -0,0 +1,56 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message\Content; + +use Symfony\AI\Platform\Message\Content\Audio; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class AudioNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof Audio; + } + + public function getSupportedTypes(?string $format): array + { + return [ + Audio::class => true, + ]; + } + + /** + * @param Audio $data + * + * @return array{type: 'input_audio', input_audio: array{ + * data: string, + * format: 'mp3'|'wav'|string, + * }} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => $data->asBase64(), + 'format' => match ($data->getFormat()) { + 'audio/mpeg' => 'mp3', + 'audio/wav' => 'wav', + default => $data->getFormat(), + }, + ], + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/Content/ImageNormalizer.php b/src/platform/src/Contract/Normalizer/Message/Content/ImageNormalizer.php new file mode 100644 index 000000000..0e32e317a --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/Content/ImageNormalizer.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message\Content; + +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class ImageNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof Image; + } + + public function getSupportedTypes(?string $format): array + { + return [ + Image::class => true, + ]; + } + + /** + * @param Image $data + * + * @return array{type: 'image_url', image_url: array{url: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'image_url', + 'image_url' => ['url' => $data->asDataUrl()], + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/Content/ImageUrlNormalizer.php b/src/platform/src/Contract/Normalizer/Message/Content/ImageUrlNormalizer.php new file mode 100644 index 000000000..6b61cd7da --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/Content/ImageUrlNormalizer.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message\Content; + +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class ImageUrlNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof ImageUrl; + } + + public function getSupportedTypes(?string $format): array + { + return [ + ImageUrl::class => true, + ]; + } + + /** + * @param ImageUrl $data + * + * @return array{type: 'image_url', image_url: array{url: string}} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'type' => 'image_url', + 'image_url' => ['url' => $data->url], + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/Content/TextNormalizer.php b/src/platform/src/Contract/Normalizer/Message/Content/TextNormalizer.php new file mode 100644 index 000000000..72860679d --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/Content/TextNormalizer.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message\Content; + +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class TextNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof Text; + } + + public function getSupportedTypes(?string $format): array + { + return [ + Text::class => true, + ]; + } + + /** + * @param Text $data + * + * @return array{type: 'text', text: string} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return ['type' => 'text', 'text' => $data->text]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/MessageBagNormalizer.php b/src/platform/src/Contract/Normalizer/Message/MessageBagNormalizer.php new file mode 100644 index 000000000..7d4229e45 --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/MessageBagNormalizer.php @@ -0,0 +1,60 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message; + +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class MessageBagNormalizer implements NormalizerInterface, NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof MessageBagInterface; + } + + public function getSupportedTypes(?string $format): array + { + return [ + MessageBagInterface::class => true, + ]; + } + + /** + * @param MessageBagInterface $data + * + * @return array{ + * messages: array, + * model?: string, + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = [ + 'messages' => $this->normalizer->normalize($data->getMessages(), $format, $context), + ]; + + if (isset($context[Contract::CONTEXT_MODEL]) && $context[Contract::CONTEXT_MODEL] instanceof Model) { + $array['model'] = $context[Contract::CONTEXT_MODEL]->getName(); + } + + return $array; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/SystemMessageNormalizer.php b/src/platform/src/Contract/Normalizer/Message/SystemMessageNormalizer.php new file mode 100644 index 000000000..4985b87d1 --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/SystemMessageNormalizer.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message; + +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class SystemMessageNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof SystemMessage; + } + + public function getSupportedTypes(?string $format): array + { + return [ + SystemMessage::class => true, + ]; + } + + /** + * @param SystemMessage $data + * + * @return array{role: 'system', content: string} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'role' => $data->getRole()->value, + 'content' => $data->content, + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/ToolCallMessageNormalizer.php b/src/platform/src/Contract/Normalizer/Message/ToolCallMessageNormalizer.php new file mode 100644 index 000000000..9661717f8 --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/ToolCallMessageNormalizer.php @@ -0,0 +1,53 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message; + +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class ToolCallMessageNormalizer implements NormalizerInterface, NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof ToolCallMessage; + } + + public function getSupportedTypes(?string $format): array + { + return [ + ToolCallMessage::class => true, + ]; + } + + /** + * @return array{ + * role: 'tool', + * content: string, + * tool_call_id: string, + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'role' => $data->getRole()->value, + 'content' => $this->normalizer->normalize($data->content, $format, $context), + 'tool_call_id' => $data->toolCall->id, + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Message/UserMessageNormalizer.php b/src/platform/src/Contract/Normalizer/Message/UserMessageNormalizer.php new file mode 100644 index 000000000..af7794f5c --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Message/UserMessageNormalizer.php @@ -0,0 +1,58 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Message; + +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; +use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class UserMessageNormalizer implements NormalizerInterface, NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof UserMessage; + } + + public function getSupportedTypes(?string $format): array + { + return [ + UserMessage::class => true, + ]; + } + + /** + * @param UserMessage $data + * + * @return array{role: 'assistant', content: string} + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $array = ['role' => $data->getRole()->value]; + + if (1 === \count($data->content) && $data->content[0] instanceof Text) { + $array['content'] = $data->content[0]->text; + + return $array; + } + + $array['content'] = $this->normalizer->normalize($data->content, $format, $context); + + return $array; + } +} diff --git a/src/platform/src/Contract/Normalizer/ModelContractNormalizer.php b/src/platform/src/Contract/Normalizer/ModelContractNormalizer.php new file mode 100644 index 000000000..c130d362c --- /dev/null +++ b/src/platform/src/Contract/Normalizer/ModelContractNormalizer.php @@ -0,0 +1,49 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer; + +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +abstract class ModelContractNormalizer implements NormalizerInterface +{ + /** + * @return class-string + */ + abstract protected function supportedDataClass(): string; + + abstract protected function supportsModel(Model $model): bool; + + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + if (!is_a($data, $this->supportedDataClass(), true)) { + return false; + } + + if (isset($context[Contract::CONTEXT_MODEL]) && $context[Contract::CONTEXT_MODEL] instanceof Model) { + return $this->supportsModel($context[Contract::CONTEXT_MODEL]); + } + + return false; + } + + public function getSupportedTypes(?string $format): array + { + return [ + $this->supportedDataClass() => true, + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/Response/ToolCallNormalizer.php b/src/platform/src/Contract/Normalizer/Response/ToolCallNormalizer.php new file mode 100644 index 000000000..1771e3233 --- /dev/null +++ b/src/platform/src/Contract/Normalizer/Response/ToolCallNormalizer.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer\Response; + +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @author Christopher Hertel + */ +final class ToolCallNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof ToolCall; + } + + public function getSupportedTypes(?string $format): array + { + return [ + ToolCall::class => true, + ]; + } + + /** + * @param ToolCall $data + * + * @return array{ + * id: string, + * type: 'function', + * function: array{ + * name: string, + * arguments: string + * } + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + return [ + 'id' => $data->id, + 'type' => 'function', + 'function' => [ + 'name' => $data->name, + 'arguments' => json_encode($data->arguments), + ], + ]; + } +} diff --git a/src/platform/src/Contract/Normalizer/ToolNormalizer.php b/src/platform/src/Contract/Normalizer/ToolNormalizer.php new file mode 100644 index 000000000..04d3e0624 --- /dev/null +++ b/src/platform/src/Contract/Normalizer/ToolNormalizer.php @@ -0,0 +1,65 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Contract\Normalizer; + +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Tool\Tool; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +/** + * @phpstan-import-type JsonSchema from Factory + * + * @author Christopher Hertel + */ +class ToolNormalizer implements NormalizerInterface +{ + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + return $data instanceof Tool; + } + + public function getSupportedTypes(?string $format): array + { + return [ + Tool::class => true, + ]; + } + + /** + * @param Tool $data + * + * @return array{ + * type: 'function', + * function: array{ + * name: string, + * description: string, + * parameters?: JsonSchema + * } + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $function = [ + 'name' => $data->name, + 'description' => $data->description, + ]; + + if (isset($data->parameters)) { + $function['parameters'] = $data->parameters; + } + + return [ + 'type' => 'function', + 'function' => $function, + ]; + } +} diff --git a/src/platform/src/Exception/ContentFilterException.php b/src/platform/src/Exception/ContentFilterException.php new file mode 100644 index 000000000..95dc56200 --- /dev/null +++ b/src/platform/src/Exception/ContentFilterException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Exception; + +/** + * @author Oskar Stark + */ +class ContentFilterException extends InvalidArgumentException +{ +} diff --git a/src/platform/src/Exception/ExceptionInterface.php b/src/platform/src/Exception/ExceptionInterface.php new file mode 100644 index 000000000..c5a865ae6 --- /dev/null +++ b/src/platform/src/Exception/ExceptionInterface.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Exception; + +/** + * @author Oskar Stark + */ +interface ExceptionInterface extends \Throwable +{ +} diff --git a/src/platform/src/Exception/InvalidArgumentException.php b/src/platform/src/Exception/InvalidArgumentException.php new file mode 100644 index 000000000..ef9bcb6ce --- /dev/null +++ b/src/platform/src/Exception/InvalidArgumentException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Exception; + +/** + * @author Oskar Stark + */ +class InvalidArgumentException extends \InvalidArgumentException implements ExceptionInterface +{ +} diff --git a/src/platform/src/Exception/RuntimeException.php b/src/platform/src/Exception/RuntimeException.php new file mode 100644 index 000000000..b99a1c956 --- /dev/null +++ b/src/platform/src/Exception/RuntimeException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Exception; + +/** + * @author Oskar Stark + */ +class RuntimeException extends \RuntimeException implements ExceptionInterface +{ +} diff --git a/src/platform/src/Message/AssistantMessage.php b/src/platform/src/Message/AssistantMessage.php new file mode 100644 index 000000000..adf314b06 --- /dev/null +++ b/src/platform/src/Message/AssistantMessage.php @@ -0,0 +1,39 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @author Denis Zunke + */ +final readonly class AssistantMessage implements MessageInterface +{ + /** + * @param ?ToolCall[] $toolCalls + */ + public function __construct( + public ?string $content = null, + public ?array $toolCalls = null, + ) { + } + + public function getRole(): Role + { + return Role::Assistant; + } + + public function hasToolCalls(): bool + { + return null !== $this->toolCalls && 0 !== \count($this->toolCalls); + } +} diff --git a/src/platform/src/Message/Content/Audio.php b/src/platform/src/Message/Content/Audio.php new file mode 100644 index 000000000..b7a5800c6 --- /dev/null +++ b/src/platform/src/Message/Content/Audio.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Christopher Hertel + */ +final readonly class Audio extends File +{ +} diff --git a/src/platform/src/Message/Content/ContentInterface.php b/src/platform/src/Message/Content/ContentInterface.php new file mode 100644 index 000000000..fd3ad307f --- /dev/null +++ b/src/platform/src/Message/Content/ContentInterface.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Denis Zunke + */ +interface ContentInterface +{ +} diff --git a/src/platform/src/Message/Content/Document.php b/src/platform/src/Message/Content/Document.php new file mode 100644 index 000000000..d6bd91440 --- /dev/null +++ b/src/platform/src/Message/Content/Document.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Christopher Hertel + */ +final readonly class Document extends File +{ +} diff --git a/src/platform/src/Message/Content/DocumentUrl.php b/src/platform/src/Message/Content/DocumentUrl.php new file mode 100644 index 000000000..e44e9e562 --- /dev/null +++ b/src/platform/src/Message/Content/DocumentUrl.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Christopher Hertel + */ +final readonly class DocumentUrl implements ContentInterface +{ + public function __construct( + public string $url, + ) { + } +} diff --git a/src/platform/src/Message/Content/File.php b/src/platform/src/Message/Content/File.php new file mode 100644 index 000000000..a6d6cbcbf --- /dev/null +++ b/src/platform/src/Message/Content/File.php @@ -0,0 +1,87 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Exception\RuntimeException; + +use function Symfony\Component\String\u; + +/** + * @author Christopher Hertel + */ +readonly class File implements ContentInterface +{ + final public function __construct( + private string|\Closure $data, + private string $format, + private ?string $path = null, + ) { + } + + public static function fromDataUrl(string $dataUrl): static + { + if (!str_starts_with($dataUrl, 'data:')) { + throw new InvalidArgumentException('Invalid audio data URL format.'); + } + + return new static( + base64_decode(u($dataUrl)->after('base64,')->toString()), + u($dataUrl)->after('data:')->before(';base64,')->toString(), + ); + } + + public static function fromFile(string $path): static + { + if (!is_readable($path)) { + throw new InvalidArgumentException(\sprintf('The file "%s" does not exist or is not readable.', $path)); + } + + return new static( + fn () => file_get_contents($path), + mime_content_type($path), + $path, + ); + } + + public function getFormat(): string + { + return $this->format; + } + + public function asBinary(): string + { + return $this->data instanceof \Closure ? ($this->data)() : $this->data; + } + + public function asBase64(): string + { + return base64_encode($this->asBinary()); + } + + public function asDataUrl(): string + { + return \sprintf('data:%s;base64,%s', $this->format, $this->asBase64()); + } + + /** + * @return resource|false + */ + public function asResource() + { + if (null === $this->path) { + throw new RuntimeException('You can only get a resource after creating fromFile.'); + } + + return fopen($this->path, 'r'); + } +} diff --git a/src/platform/src/Message/Content/Image.php b/src/platform/src/Message/Content/Image.php new file mode 100644 index 000000000..1a98399e8 --- /dev/null +++ b/src/platform/src/Message/Content/Image.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Denis Zunke + */ +final readonly class Image extends File +{ +} diff --git a/src/platform/src/Message/Content/ImageUrl.php b/src/platform/src/Message/Content/ImageUrl.php new file mode 100644 index 000000000..63420ff4a --- /dev/null +++ b/src/platform/src/Message/Content/ImageUrl.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Christopher Hertel + */ +final readonly class ImageUrl implements ContentInterface +{ + public function __construct( + public string $url, + ) { + } +} diff --git a/src/platform/src/Message/Content/Text.php b/src/platform/src/Message/Content/Text.php new file mode 100644 index 000000000..ddf371aff --- /dev/null +++ b/src/platform/src/Message/Content/Text.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message\Content; + +/** + * @author Denis Zunke + */ +final readonly class Text implements ContentInterface +{ + public function __construct( + public string $text, + ) { + } +} diff --git a/src/platform/src/Message/Message.php b/src/platform/src/Message/Message.php new file mode 100644 index 000000000..6b2b9d46d --- /dev/null +++ b/src/platform/src/Message/Message.php @@ -0,0 +1,56 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +use Symfony\AI\Platform\Message\Content\ContentInterface; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @author Christopher Hertel + * @author Denis Zunke + */ +final readonly class Message +{ + // Disabled by default, just a bridge to the specific messages + private function __construct() + { + } + + public static function forSystem(string $content): SystemMessage + { + return new SystemMessage($content); + } + + /** + * @param ?ToolCall[] $toolCalls + */ + public static function ofAssistant(?string $content = null, ?array $toolCalls = null): AssistantMessage + { + return new AssistantMessage($content, $toolCalls); + } + + public static function ofUser(string|ContentInterface ...$content): UserMessage + { + $content = array_map( + static fn (string|ContentInterface $entry) => \is_string($entry) ? new Text($entry) : $entry, + $content, + ); + + return new UserMessage(...$content); + } + + public static function ofToolCall(ToolCall $toolCall, string $content): ToolCallMessage + { + return new ToolCallMessage($toolCall, $content); + } +} diff --git a/src/platform/src/Message/MessageBag.php b/src/platform/src/Message/MessageBag.php new file mode 100644 index 000000000..87ad0c5b6 --- /dev/null +++ b/src/platform/src/Message/MessageBag.php @@ -0,0 +1,116 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +/** + * @final + * + * @author Christopher Hertel + */ +class MessageBag implements MessageBagInterface +{ + /** + * @var list + */ + private array $messages; + + public function __construct(MessageInterface ...$messages) + { + $this->messages = array_values($messages); + } + + public function add(MessageInterface $message): void + { + $this->messages[] = $message; + } + + /** + * @return list + */ + public function getMessages(): array + { + return $this->messages; + } + + public function getSystemMessage(): ?SystemMessage + { + foreach ($this->messages as $message) { + if ($message instanceof SystemMessage) { + return $message; + } + } + + return null; + } + + public function with(MessageInterface $message): self + { + $messages = clone $this; + $messages->add($message); + + return $messages; + } + + public function merge(MessageBagInterface $messageBag): self + { + $messages = clone $this; + $messages->messages = array_merge($messages->messages, $messageBag->getMessages()); + + return $messages; + } + + public function withoutSystemMessage(): self + { + $messages = clone $this; + $messages->messages = array_values(array_filter( + $messages->messages, + static fn (MessageInterface $message) => !$message instanceof SystemMessage, + )); + + return $messages; + } + + public function prepend(MessageInterface $message): self + { + $messages = clone $this; + $messages->messages = array_merge([$message], $messages->messages); + + return $messages; + } + + public function containsAudio(): bool + { + foreach ($this->messages as $message) { + if ($message instanceof UserMessage && $message->hasAudioContent()) { + return true; + } + } + + return false; + } + + public function containsImage(): bool + { + foreach ($this->messages as $message) { + if ($message instanceof UserMessage && $message->hasImageContent()) { + return true; + } + } + + return false; + } + + public function count(): int + { + return \count($this->messages); + } +} diff --git a/src/platform/src/Message/MessageBagInterface.php b/src/platform/src/Message/MessageBagInterface.php new file mode 100644 index 000000000..070c51969 --- /dev/null +++ b/src/platform/src/Message/MessageBagInterface.php @@ -0,0 +1,39 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +/** + * @author Oskar Stark + */ +interface MessageBagInterface extends \Countable +{ + public function add(MessageInterface $message): void; + + /** + * @return list + */ + public function getMessages(): array; + + public function getSystemMessage(): ?SystemMessage; + + public function with(MessageInterface $message): self; + + public function merge(self $messageBag): self; + + public function withoutSystemMessage(): self; + + public function prepend(MessageInterface $message): self; + + public function containsAudio(): bool; + + public function containsImage(): bool; +} diff --git a/src/platform/src/Message/MessageInterface.php b/src/platform/src/Message/MessageInterface.php new file mode 100644 index 000000000..fe1e4233a --- /dev/null +++ b/src/platform/src/Message/MessageInterface.php @@ -0,0 +1,20 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +/** + * @author Denis Zunke + */ +interface MessageInterface +{ + public function getRole(): Role; +} diff --git a/src/platform/src/Message/Role.php b/src/platform/src/Message/Role.php new file mode 100644 index 000000000..39cfa23ce --- /dev/null +++ b/src/platform/src/Message/Role.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +use OskarStark\Enum\Trait\Comparable; + +/** + * @author Christopher Hertel + */ +enum Role: string +{ + use Comparable; + + case System = 'system'; + case Assistant = 'assistant'; + case User = 'user'; + case ToolCall = 'tool'; +} diff --git a/src/platform/src/Message/SystemMessage.php b/src/platform/src/Message/SystemMessage.php new file mode 100644 index 000000000..9c496123e --- /dev/null +++ b/src/platform/src/Message/SystemMessage.php @@ -0,0 +1,27 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +/** + * @author Denis Zunke + */ +final readonly class SystemMessage implements MessageInterface +{ + public function __construct(public string $content) + { + } + + public function getRole(): Role + { + return Role::System; + } +} diff --git a/src/platform/src/Message/ToolCallMessage.php b/src/platform/src/Message/ToolCallMessage.php new file mode 100644 index 000000000..c0a7aac0b --- /dev/null +++ b/src/platform/src/Message/ToolCallMessage.php @@ -0,0 +1,31 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +use Symfony\AI\Platform\Response\ToolCall; + +/** + * @author Denis Zunke + */ +final readonly class ToolCallMessage implements MessageInterface +{ + public function __construct( + public ToolCall $toolCall, + public string $content, + ) { + } + + public function getRole(): Role + { + return Role::ToolCall; + } +} diff --git a/src/platform/src/Message/UserMessage.php b/src/platform/src/Message/UserMessage.php new file mode 100644 index 000000000..78c7e3f5b --- /dev/null +++ b/src/platform/src/Message/UserMessage.php @@ -0,0 +1,61 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Message; + +use Symfony\AI\Platform\Message\Content\Audio; +use Symfony\AI\Platform\Message\Content\ContentInterface; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Content\ImageUrl; + +/** + * @author Denis Zunke + */ +final readonly class UserMessage implements MessageInterface +{ + /** + * @var list + */ + public array $content; + + public function __construct( + ContentInterface ...$content, + ) { + $this->content = $content; + } + + public function getRole(): Role + { + return Role::User; + } + + public function hasAudioContent(): bool + { + foreach ($this->content as $content) { + if ($content instanceof Audio) { + return true; + } + } + + return false; + } + + public function hasImageContent(): bool + { + foreach ($this->content as $content) { + if ($content instanceof Image || $content instanceof ImageUrl) { + return true; + } + } + + return false; + } +} diff --git a/src/platform/src/Model.php b/src/platform/src/Model.php new file mode 100644 index 000000000..8d572b124 --- /dev/null +++ b/src/platform/src/Model.php @@ -0,0 +1,55 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +/** + * @author Christopher Hertel + */ +class Model +{ + /** + * @param string[] $capabilities + * @param array $options + */ + public function __construct( + private readonly string $name, + private readonly array $capabilities = [], + private readonly array $options = [], + ) { + } + + public function getName(): string + { + return $this->name; + } + + /** + * @return string[] + */ + public function getCapabilities(): array + { + return $this->capabilities; + } + + public function supports(string $capability): bool + { + return \in_array($capability, $this->capabilities, true); + } + + /** + * @return array + */ + public function getOptions(): array + { + return $this->options; + } +} diff --git a/src/platform/src/ModelClientInterface.php b/src/platform/src/ModelClientInterface.php new file mode 100644 index 000000000..76b6ae875 --- /dev/null +++ b/src/platform/src/ModelClientInterface.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Christopher Hertel + */ +interface ModelClientInterface +{ + public function supports(Model $model): bool; + + /** + * @param array $payload + * @param array $options + */ + public function request(Model $model, array|string $payload, array $options = []): ResponseInterface; +} diff --git a/src/platform/src/Platform.php b/src/platform/src/Platform.php new file mode 100644 index 000000000..b45b92348 --- /dev/null +++ b/src/platform/src/Platform.php @@ -0,0 +1,90 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Response\AsyncResponse; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + */ +final class Platform implements PlatformInterface +{ + /** + * @var ModelClientInterface[] + */ + private readonly array $modelClients; + + /** + * @var ResponseConverterInterface[] + */ + private readonly array $responseConverter; + + /** + * @param iterable $modelClients + * @param iterable $responseConverter + */ + public function __construct( + iterable $modelClients, + iterable $responseConverter, + private ?Contract $contract = null, + ) { + $this->contract = $contract ?? Contract::create(); + $this->modelClients = $modelClients instanceof \Traversable ? iterator_to_array($modelClients) : $modelClients; + $this->responseConverter = $responseConverter instanceof \Traversable ? iterator_to_array($responseConverter) : $responseConverter; + } + + public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface + { + $payload = $this->contract->createRequestPayload($model, $input); + $options = array_merge($model->getOptions(), $options); + + if (isset($options['tools'])) { + $options['tools'] = $this->contract->createToolOption($options['tools'], $model); + } + + $response = $this->doRequest($model, $payload, $options); + + return $this->convertResponse($model, $response, $options); + } + + /** + * @param array $payload + * @param array $options + */ + private function doRequest(Model $model, array|string $payload, array $options = []): HttpResponse + { + foreach ($this->modelClients as $modelClient) { + if ($modelClient->supports($model)) { + return $modelClient->request($model, $payload, $options); + } + } + + throw new RuntimeException('No response factory registered for model "'.$model::class.'" with given input.'); + } + + /** + * @param array $options + */ + private function convertResponse(Model $model, HttpResponse $response, array $options): ResponseInterface + { + foreach ($this->responseConverter as $responseConverter) { + if ($responseConverter->supports($model)) { + return new AsyncResponse($responseConverter, $response, $options); + } + } + + throw new RuntimeException('No response converter registered for model "'.$model::class.'" with given input.'); + } +} diff --git a/src/platform/src/PlatformInterface.php b/src/platform/src/PlatformInterface.php new file mode 100644 index 000000000..ba0208615 --- /dev/null +++ b/src/platform/src/PlatformInterface.php @@ -0,0 +1,26 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +use Symfony\AI\Platform\Response\ResponseInterface; + +/** + * @author Christopher Hertel + */ +interface PlatformInterface +{ + /** + * @param array|string|object $input + * @param array $options + */ + public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface; +} diff --git a/src/platform/src/Response/AsyncResponse.php b/src/platform/src/Response/AsyncResponse.php new file mode 100644 index 000000000..eb9aef107 --- /dev/null +++ b/src/platform/src/Response/AsyncResponse.php @@ -0,0 +1,83 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; +use Symfony\AI\Platform\Response\Metadata\MetadataAwareTrait; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + */ +final class AsyncResponse implements ResponseInterface +{ + use MetadataAwareTrait; + + private bool $isConverted = false; + private ResponseInterface $convertedResponse; + + /** + * @param array $options + */ + public function __construct( + private readonly ResponseConverterInterface $responseConverter, + private readonly HttpResponse $response, + private readonly array $options = [], + ) { + } + + public function getContent(): string|iterable|object|null + { + return $this->unwrap()->getContent(); + } + + public function getRawResponse(): HttpResponse + { + return $this->response; + } + + public function setRawResponse(HttpResponse $rawResponse): void + { + // Empty by design as the raw response is already set in the constructor and must only be set once + throw new RawResponseAlreadySetException(); + } + + public function unwrap(): ResponseInterface + { + if (!$this->isConverted) { + $this->convertedResponse = $this->responseConverter->convert($this->response, $this->options); + + if (null === $this->convertedResponse->getRawResponse()) { + // Fallback to set the raw response when it was not handled by the response converter itself + $this->convertedResponse->setRawResponse($this->response); + } + + $this->isConverted = true; + } + + return $this->convertedResponse; + } + + /** + * @param array $arguments + */ + public function __call(string $name, array $arguments): mixed + { + return $this->unwrap()->{$name}(...$arguments); + } + + public function __get(string $name): mixed + { + return $this->unwrap()->{$name}; + } +} diff --git a/src/platform/src/Response/BaseResponse.php b/src/platform/src/Response/BaseResponse.php new file mode 100644 index 000000000..78013e917 --- /dev/null +++ b/src/platform/src/Response/BaseResponse.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Response\Metadata\MetadataAwareTrait; + +/** + * @author Denis Zunke + */ +abstract class BaseResponse implements ResponseInterface +{ + use MetadataAwareTrait; + use RawResponseAwareTrait; +} diff --git a/src/platform/src/Response/BinaryResponse.php b/src/platform/src/Response/BinaryResponse.php new file mode 100644 index 000000000..460df8ef3 --- /dev/null +++ b/src/platform/src/Response/BinaryResponse.php @@ -0,0 +1,45 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Exception\RuntimeException; + +/** + * @author Christopher Hertel + */ +final class BinaryResponse extends BaseResponse +{ + public function __construct( + public string $data, + public ?string $mimeType = null, + ) { + } + + public function getContent(): string + { + return $this->data; + } + + public function toBase64(): string + { + return base64_encode($this->data); + } + + public function toDataUri(): string + { + if (null === $this->mimeType) { + throw new RuntimeException('Mime type is not set.'); + } + + return 'data:'.$this->mimeType.';base64,'.$this->toBase64(); + } +} diff --git a/src/platform/src/Response/Choice.php b/src/platform/src/Response/Choice.php new file mode 100644 index 000000000..46e1e3883 --- /dev/null +++ b/src/platform/src/Response/Choice.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +/** + * @author Christopher Hertel + */ +final readonly class Choice +{ + /** + * @param ToolCall[] $toolCalls + */ + public function __construct( + private ?string $content = null, + private array $toolCalls = [], + ) { + } + + public function getContent(): ?string + { + return $this->content; + } + + public function hasContent(): bool + { + return null !== $this->content; + } + + /** + * @return ToolCall[] + */ + public function getToolCalls(): array + { + return $this->toolCalls; + } + + public function hasToolCall(): bool + { + return 0 !== \count($this->toolCalls); + } +} diff --git a/src/platform/src/Response/ChoiceResponse.php b/src/platform/src/Response/ChoiceResponse.php new file mode 100644 index 000000000..d81f34957 --- /dev/null +++ b/src/platform/src/Response/ChoiceResponse.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; + +/** + * @author Christopher Hertel + */ +final class ChoiceResponse extends BaseResponse +{ + /** + * @var Choice[] + */ + private readonly array $choices; + + public function __construct(Choice ...$choices) + { + if (0 === \count($choices)) { + throw new InvalidArgumentException('Response must have at least one choice.'); + } + + $this->choices = $choices; + } + + /** + * @return Choice[] + */ + public function getContent(): array + { + return $this->choices; + } +} diff --git a/src/platform/src/Response/Exception/RawResponseAlreadySetException.php b/src/platform/src/Response/Exception/RawResponseAlreadySetException.php new file mode 100644 index 000000000..a8a6c1ca8 --- /dev/null +++ b/src/platform/src/Response/Exception/RawResponseAlreadySetException.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response\Exception; + +use Symfony\AI\Platform\Exception\RuntimeException; + +/** + * @author Denis Zunke + */ +final class RawResponseAlreadySetException extends RuntimeException +{ + public function __construct() + { + parent::__construct('The raw response was already set.'); + } +} diff --git a/src/platform/src/Response/Metadata/Metadata.php b/src/platform/src/Response/Metadata/Metadata.php new file mode 100644 index 000000000..ea6f04f1c --- /dev/null +++ b/src/platform/src/Response/Metadata/Metadata.php @@ -0,0 +1,108 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response\Metadata; + +/** + * @implements \IteratorAggregate + * @implements \ArrayAccess + * + * @author Denis Zunke + */ +class Metadata implements \JsonSerializable, \Countable, \IteratorAggregate, \ArrayAccess +{ + /** + * @var array + */ + private array $metadata = []; + + /** + * @param array $metadata + */ + public function __construct(array $metadata = []) + { + $this->set($metadata); + } + + /** + * @return array + */ + public function all(): array + { + return $this->metadata; + } + + /** + * @param array $metadata + */ + public function set(array $metadata): void + { + $this->metadata = $metadata; + } + + public function add(string $key, mixed $value): void + { + $this->metadata[$key] = $value; + } + + public function has(string $key): bool + { + return \array_key_exists($key, $this->metadata); + } + + public function get(string $key, mixed $default = null): mixed + { + return $this->metadata[$key] ?? $default; + } + + public function remove(string $key): void + { + unset($this->metadata[$key]); + } + + /** + * @return array + */ + public function jsonSerialize(): array + { + return $this->all(); + } + + public function count(): int + { + return \count($this->metadata); + } + + public function getIterator(): \Traversable + { + return new \ArrayIterator($this->metadata); + } + + public function offsetExists(mixed $offset): bool + { + return $this->has((string) $offset); + } + + public function offsetGet(mixed $offset): mixed + { + return $this->get((string) $offset); + } + + public function offsetSet(mixed $offset, mixed $value): void + { + $this->add((string) $offset, $value); + } + + public function offsetUnset(mixed $offset): void + { + $this->remove((string) $offset); + } +} diff --git a/src/platform/src/Response/Metadata/MetadataAwareTrait.php b/src/platform/src/Response/Metadata/MetadataAwareTrait.php new file mode 100644 index 000000000..ed3fffa62 --- /dev/null +++ b/src/platform/src/Response/Metadata/MetadataAwareTrait.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response\Metadata; + +/** + * @author Denis Zunke + */ +trait MetadataAwareTrait +{ + private ?Metadata $metadata = null; + + public function getMetadata(): Metadata + { + return $this->metadata ??= new Metadata(); + } +} diff --git a/src/platform/src/Response/ObjectResponse.php b/src/platform/src/Response/ObjectResponse.php new file mode 100644 index 000000000..f98ba36f9 --- /dev/null +++ b/src/platform/src/Response/ObjectResponse.php @@ -0,0 +1,34 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +/** + * @author Christopher Hertel + */ +final class ObjectResponse extends BaseResponse +{ + /** + * @param object|array $structuredOutput + */ + public function __construct( + private readonly object|array $structuredOutput, + ) { + } + + /** + * @return object|array + */ + public function getContent(): object|array + { + return $this->structuredOutput; + } +} diff --git a/src/platform/src/Response/RawResponseAwareTrait.php b/src/platform/src/Response/RawResponseAwareTrait.php new file mode 100644 index 000000000..029ab77b7 --- /dev/null +++ b/src/platform/src/Response/RawResponseAwareTrait.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +/** + * @author Denis Zunke + */ +trait RawResponseAwareTrait +{ + protected ?SymfonyHttpResponse $rawResponse = null; + + public function setRawResponse(SymfonyHttpResponse $rawResponse): void + { + if (null !== $this->rawResponse) { + throw new RawResponseAlreadySetException(); + } + + $this->rawResponse = $rawResponse; + } + + public function getRawResponse(): ?SymfonyHttpResponse + { + return $this->rawResponse; + } +} diff --git a/src/platform/src/Response/ResponseInterface.php b/src/platform/src/Response/ResponseInterface.php new file mode 100644 index 000000000..c8738240f --- /dev/null +++ b/src/platform/src/Response/ResponseInterface.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; +use Symfony\AI\Platform\Response\Metadata\Metadata; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +/** + * @author Christopher Hertel + * @author Denis Zunke + */ +interface ResponseInterface +{ + /** + * @return string|iterable|object|null + */ + public function getContent(): string|iterable|object|null; + + public function getMetadata(): Metadata; + + public function getRawResponse(): ?SymfonyHttpResponse; + + /** + * @throws RawResponseAlreadySetException if the response is tried to be set more than once + */ + public function setRawResponse(SymfonyHttpResponse $rawResponse): void; +} diff --git a/src/platform/src/Response/StreamResponse.php b/src/platform/src/Response/StreamResponse.php new file mode 100644 index 000000000..684c4d8d0 --- /dev/null +++ b/src/platform/src/Response/StreamResponse.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +/** + * @author Christopher Hertel + */ +final class StreamResponse extends BaseResponse +{ + public function __construct( + private readonly \Generator $generator, + ) { + } + + public function getContent(): \Generator + { + yield from $this->generator; + } +} diff --git a/src/platform/src/Response/TextResponse.php b/src/platform/src/Response/TextResponse.php new file mode 100644 index 000000000..7bb9047c0 --- /dev/null +++ b/src/platform/src/Response/TextResponse.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +/** + * @author Christopher Hertel + */ +final class TextResponse extends BaseResponse +{ + public function __construct( + private readonly string $content, + ) { + } + + public function getContent(): string + { + return $this->content; + } +} diff --git a/src/platform/src/Response/ToolCall.php b/src/platform/src/Response/ToolCall.php new file mode 100644 index 000000000..8633632fc --- /dev/null +++ b/src/platform/src/Response/ToolCall.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +/** + * @author Christopher Hertel + */ +final readonly class ToolCall implements \JsonSerializable +{ + /** + * @param array $arguments + */ + public function __construct( + public string $id, + public string $name, + public array $arguments = [], + ) { + } + + /** + * @return array{ + * id: string, + * type: 'function', + * function: array{ + * name: string, + * arguments: string + * } + * } + */ + public function jsonSerialize(): array + { + return [ + 'id' => $this->id, + 'type' => 'function', + 'function' => [ + 'name' => $this->name, + 'arguments' => json_encode($this->arguments), + ], + ]; + } +} diff --git a/src/platform/src/Response/ToolCallResponse.php b/src/platform/src/Response/ToolCallResponse.php new file mode 100644 index 000000000..fcdc94c3b --- /dev/null +++ b/src/platform/src/Response/ToolCallResponse.php @@ -0,0 +1,42 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; + +/** + * @author Christopher Hertel + */ +final class ToolCallResponse extends BaseResponse +{ + /** + * @var ToolCall[] + */ + private readonly array $toolCalls; + + public function __construct(ToolCall ...$toolCalls) + { + if (0 === \count($toolCalls)) { + throw new InvalidArgumentException('Response must have at least one tool call.'); + } + + $this->toolCalls = $toolCalls; + } + + /** + * @return ToolCall[] + */ + public function getContent(): array + { + return $this->toolCalls; + } +} diff --git a/src/platform/src/Response/VectorResponse.php b/src/platform/src/Response/VectorResponse.php new file mode 100644 index 000000000..e63c898c4 --- /dev/null +++ b/src/platform/src/Response/VectorResponse.php @@ -0,0 +1,38 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Response; + +use Symfony\AI\Platform\Vector\Vector; + +/** + * @author Christopher Hertel + */ +final class VectorResponse extends BaseResponse +{ + /** + * @var Vector[] + */ + private readonly array $vectors; + + public function __construct(Vector ...$vector) + { + $this->vectors = $vector; + } + + /** + * @return Vector[] + */ + public function getContent(): array + { + return $this->vectors; + } +} diff --git a/src/platform/src/ResponseConverterInterface.php b/src/platform/src/ResponseConverterInterface.php new file mode 100644 index 000000000..9bdfd0551 --- /dev/null +++ b/src/platform/src/ResponseConverterInterface.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform; + +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Christopher Hertel + */ +interface ResponseConverterInterface +{ + public function supports(Model $model): bool; + + /** + * @param array $options + */ + public function convert(HttpResponse $response, array $options = []): LlmResponse; +} diff --git a/src/platform/src/Tool/ExecutionReference.php b/src/platform/src/Tool/ExecutionReference.php new file mode 100644 index 000000000..0a68c005a --- /dev/null +++ b/src/platform/src/Tool/ExecutionReference.php @@ -0,0 +1,24 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tool; + +/** + * @author Christopher Hertel + */ +final class ExecutionReference +{ + public function __construct( + public string $class, + public string $method = '__invoke', + ) { + } +} diff --git a/src/platform/src/Tool/Tool.php b/src/platform/src/Tool/Tool.php new file mode 100644 index 000000000..826a04f4c --- /dev/null +++ b/src/platform/src/Tool/Tool.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tool; + +use Symfony\AI\Platform\Contract\JsonSchema\Factory; + +/** + * @phpstan-import-type JsonSchema from Factory + * + * @author Christopher Hertel + */ +final readonly class Tool +{ + /** + * @param JsonSchema|null $parameters + */ + public function __construct( + public ExecutionReference $reference, + public string $name, + public string $description, + public ?array $parameters = null, + ) { + } +} diff --git a/src/platform/src/Vector/NullVector.php b/src/platform/src/Vector/NullVector.php new file mode 100644 index 000000000..f714efaff --- /dev/null +++ b/src/platform/src/Vector/NullVector.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Vector; + +use Symfony\AI\Platform\Exception\RuntimeException; + +/** + * @author Oskar Stark + */ +final class NullVector implements VectorInterface +{ + public function getData(): array + { + throw new RuntimeException('getData() method cannot be called on a NullVector.'); + } + + public function getDimensions(): int + { + throw new RuntimeException('getDimensions() method cannot be called on a NullVector.'); + } +} diff --git a/src/platform/src/Vector/Vector.php b/src/platform/src/Vector/Vector.php new file mode 100644 index 000000000..f09f08a6a --- /dev/null +++ b/src/platform/src/Vector/Vector.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Vector; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; + +/** + * @author Christopher Hertel + */ +final class Vector implements VectorInterface +{ + /** + * @param list $data + */ + public function __construct( + private readonly array $data, + private ?int $dimensions = null, + ) { + if (null !== $dimensions && $dimensions !== \count($data)) { + throw new InvalidArgumentException('Vector must have '.$dimensions.' dimensions'); + } + + if (0 === \count($data)) { + throw new InvalidArgumentException('Vector must have at least one dimension'); + } + + if (\is_int($dimensions) && \count($data) !== $dimensions) { + throw new InvalidArgumentException('Vector must have '.$dimensions.' dimensions'); + } + + if (null === $this->dimensions) { + $this->dimensions = \count($data); + } + } + + /** + * @return list + */ + public function getData(): array + { + return $this->data; + } + + public function getDimensions(): int + { + return $this->dimensions; + } +} diff --git a/src/platform/src/Vector/VectorInterface.php b/src/platform/src/Vector/VectorInterface.php new file mode 100644 index 000000000..78ea1a933 --- /dev/null +++ b/src/platform/src/Vector/VectorInterface.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Vector; + +/** + * @author Oskar Stark + */ +interface VectorInterface +{ + /** + * @return list + */ + public function getData(): array; + + public function getDimensions(): int; +} diff --git a/src/platform/tests/Bridge/Anthropic/ResponseConverterTest.php b/src/platform/tests/Bridge/Anthropic/ResponseConverterTest.php new file mode 100644 index 000000000..58c989719 --- /dev/null +++ b/src/platform/tests/Bridge/Anthropic/ResponseConverterTest.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Anthropic; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Anthropic\ResponseConverter; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; + +#[CoversClass(ResponseConverter::class)] +#[Small] +#[UsesClass(ToolCall::class)] +#[UsesClass(ToolCallResponse::class)] +final class ResponseConverterTest extends TestCase +{ + public function testConvertThrowsExceptionWhenContentIsToolUseAndLacksText(): void + { + $httpClient = new MockHttpClient(new JsonMockResponse([ + 'content' => [ + [ + 'type' => 'tool_use', + 'id' => 'toolu_01UM4PcTjC1UDiorSXVHSVFM', + 'name' => 'xxx_tool', + 'input' => ['action' => 'get_data'], + ], + ], + ])); + $httpResponse = $httpClient->request('POST', 'https://api.anthropic.com/v1/messages'); + $handler = new ResponseConverter(); + + $response = $handler->convert($httpResponse); + self::assertInstanceOf(ToolCallResponse::class, $response); + self::assertCount(1, $response->getContent()); + self::assertSame('toolu_01UM4PcTjC1UDiorSXVHSVFM', $response->getContent()[0]->id); + self::assertSame('xxx_tool', $response->getContent()[0]->name); + self::assertSame(['action' => 'get_data'], $response->getContent()[0]->arguments); + } +} diff --git a/src/platform/tests/Bridge/Bedrock/Nova/ContractTest.php b/src/platform/tests/Bridge/Bedrock/Nova/ContractTest.php new file mode 100644 index 000000000..c152cdfbb --- /dev/null +++ b/src/platform/tests/Bridge/Bedrock/Nova/ContractTest.php @@ -0,0 +1,145 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Bedrock\Nova; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Medium; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract\ToolNormalizer; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Contract\UserMessageNormalizer; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Response\ToolCall; + +#[Medium] +#[CoversClass(AssistantMessageNormalizer::class)] +#[CoversClass(MessageBagNormalizer::class)] +#[CoversClass(ToolCallMessageNormalizer::class)] +#[CoversClass(ToolNormalizer::class)] +#[CoversClass(UserMessageNormalizer::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(ToolCallMessage::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(MessageBag::class)] +final class ContractTest extends TestCase +{ + #[Test] + #[DataProvider('provideMessageBag')] + public function testConvert(MessageBag $bag, array $expected): void + { + $contract = Contract::create( + new AssistantMessageNormalizer(), + new MessageBagNormalizer(), + new ToolCallMessageNormalizer(), + new ToolNormalizer(), + new UserMessageNormalizer(), + ); + + self::assertEquals($expected, $contract->createRequestPayload(new Nova(), $bag)); + } + + /** + * @return iterable + */ + public static function provideMessageBag(): iterable + { + yield 'simple text' => [ + new MessageBag(Message::ofUser('Write a story about a magic backpack.')), + [ + 'messages' => [ + ['role' => 'user', 'content' => [['text' => 'Write a story about a magic backpack.']]], + ], + ], + ]; + + yield 'with assistant message' => [ + new MessageBag( + Message::ofUser('Hello'), + Message::ofAssistant('Great to meet you. What would you like to know?'), + Message::ofUser('I have two dogs in my house. How many paws are in my house?'), + ), + [ + 'messages' => [ + ['role' => 'user', 'content' => [['text' => 'Hello']]], + ['role' => 'assistant', 'content' => [['text' => 'Great to meet you. What would you like to know?']]], + ['role' => 'user', 'content' => [['text' => 'I have two dogs in my house. How many paws are in my house?']]], + ], + ], + ]; + + yield 'with system messages' => [ + new MessageBag( + Message::forSystem('You are a cat. Your name is Neko.'), + Message::ofUser('Hello there'), + ), + [ + 'system' => [['text' => 'You are a cat. Your name is Neko.']], + 'messages' => [ + ['role' => 'user', 'content' => [['text' => 'Hello there']]], + ], + ], + ]; + + yield 'with tool use' => [ + new MessageBag( + Message::ofUser('Hello there, what is the time?'), + Message::ofAssistant(toolCalls: [new ToolCall('123456', 'clock', [])]), + Message::ofToolCall(new ToolCall('123456', 'clock', []), '2023-10-01T10:00:00+00:00'), + Message::ofAssistant('It is 10:00 AM.'), + ), + [ + 'messages' => [ + ['role' => 'user', 'content' => [['text' => 'Hello there, what is the time?']]], + [ + 'role' => 'assistant', + 'content' => [ + [ + 'toolUse' => [ + 'toolUseId' => '123456', + 'name' => 'clock', + 'input' => new \stdClass(), + ], + ], + ], + ], + [ + 'role' => 'user', + 'content' => [ + [ + 'toolResult' => [ + 'toolUseId' => '123456', + 'content' => [ + ['json' => '2023-10-01T10:00:00+00:00'], + ], + ], + ], + ], + ], + ['role' => 'assistant', 'content' => [['text' => 'It is 10:00 AM.']]], + ], + ], + ]; + } +} diff --git a/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php b/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php new file mode 100644 index 000000000..27101c318 --- /dev/null +++ b/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php @@ -0,0 +1,61 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Google\Contract; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Model; + +#[Small] +#[CoversClass(AssistantMessageNormalizer::class)] +#[UsesClass(Gemini::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(Model::class)] +final class AssistantMessageNormalizerTest extends TestCase +{ + #[Test] + public function supportsNormalization(): void + { + $normalizer = new AssistantMessageNormalizer(); + + self::assertTrue($normalizer->supportsNormalization(new AssistantMessage('Hello'), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not an assistant message')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new AssistantMessageNormalizer(); + + self::assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalize(): void + { + $normalizer = new AssistantMessageNormalizer(); + $message = new AssistantMessage('Great to meet you. What would you like to know?'); + + $normalized = $normalizer->normalize($message); + + self::assertSame([['text' => 'Great to meet you. What would you like to know?']], $normalized); + } +} diff --git a/src/platform/tests/Bridge/Google/Contract/MessageBagNormalizerTest.php b/src/platform/tests/Bridge/Google/Contract/MessageBagNormalizerTest.php new file mode 100644 index 000000000..c4b8da61d --- /dev/null +++ b/src/platform/tests/Bridge/Google/Contract/MessageBagNormalizerTest.php @@ -0,0 +1,157 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Google\Contract; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Medium; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +#[Medium] +#[CoversClass(MessageBagNormalizer::class)] +#[CoversClass(UserMessageNormalizer::class)] +#[CoversClass(AssistantMessageNormalizer::class)] +#[UsesClass(Model::class)] +#[UsesClass(Gemini::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(AssistantMessage::class)] +final class MessageBagNormalizerTest extends TestCase +{ + #[Test] + public function supportsNormalization(): void + { + $normalizer = new MessageBagNormalizer(); + + self::assertTrue($normalizer->supportsNormalization(new MessageBag(), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not a message bag')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new MessageBagNormalizer(); + + $expected = [ + MessageBagInterface::class => true, + ]; + + self::assertSame($expected, $normalizer->getSupportedTypes(null)); + } + + #[Test] + #[DataProvider('provideMessageBagData')] + public function normalize(MessageBag $bag, array $expected): void + { + $normalizer = new MessageBagNormalizer(); + + // Set up the inner normalizers + $userMessageNormalizer = new UserMessageNormalizer(); + $assistantMessageNormalizer = new AssistantMessageNormalizer(); + + // Mock a normalizer that delegates to the appropriate concrete normalizer + $mockNormalizer = $this->createMock(NormalizerInterface::class); + $mockNormalizer->method('normalize') + ->willReturnCallback(function ($message) use ($userMessageNormalizer, $assistantMessageNormalizer): ?array { + if ($message instanceof UserMessage) { + return $userMessageNormalizer->normalize($message); + } + if ($message instanceof AssistantMessage) { + return $assistantMessageNormalizer->normalize($message); + } + + return null; + }); + + $normalizer->setNormalizer($mockNormalizer); + + $normalized = $normalizer->normalize($bag); + + self::assertEquals($expected, $normalized); + } + + /** + * @return iterable + */ + public static function provideMessageBagData(): iterable + { + yield 'simple text' => [ + new MessageBag(Message::ofUser('Write a story about a magic backpack.')), + [ + 'contents' => [ + ['role' => 'user', 'parts' => [['text' => 'Write a story about a magic backpack.']]], + ], + ], + ]; + + yield 'text with image' => [ + new MessageBag( + Message::ofUser('Tell me about this instrument', Image::fromFile(\dirname(__DIR__, 4).'/Fixture/image.jpg')) + ), + [ + 'contents' => [ + ['role' => 'user', 'parts' => [ + ['text' => 'Tell me about this instrument'], + ['inline_data' => ['mime_type' => 'image/jpeg', 'data' => base64_encode(file_get_contents(\dirname(__DIR__, 4).'/Fixture/image.jpg'))]], + ]], + ], + ], + ]; + + yield 'with assistant message' => [ + new MessageBag( + Message::ofUser('Hello'), + Message::ofAssistant('Great to meet you. What would you like to know?'), + Message::ofUser('I have two dogs in my house. How many paws are in my house?'), + ), + [ + 'contents' => [ + ['role' => 'user', 'parts' => [['text' => 'Hello']]], + ['role' => 'model', 'parts' => [['text' => 'Great to meet you. What would you like to know?']]], + ['role' => 'user', 'parts' => [['text' => 'I have two dogs in my house. How many paws are in my house?']]], + ], + ], + ]; + + yield 'with system messages' => [ + new MessageBag( + Message::forSystem('You are a cat. Your name is Neko.'), + Message::ofUser('Hello there'), + ), + [ + 'contents' => [ + ['role' => 'user', 'parts' => [['text' => 'Hello there']]], + ], + 'system_instruction' => [ + 'parts' => ['text' => 'You are a cat. Your name is Neko.'], + ], + ], + ]; + } +} diff --git a/src/platform/tests/Bridge/Google/Contract/UserMessageNormalizerTest.php b/src/platform/tests/Bridge/Google/Contract/UserMessageNormalizerTest.php new file mode 100644 index 000000000..1ce885edb --- /dev/null +++ b/src/platform/tests/Bridge/Google/Contract/UserMessageNormalizerTest.php @@ -0,0 +1,83 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Google\Contract; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\Content\File; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\UserMessage; + +#[Small] +#[CoversClass(UserMessageNormalizer::class)] +#[UsesClass(Gemini::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(Text::class)] +#[UsesClass(File::class)] +final class UserMessageNormalizerTest extends TestCase +{ + #[Test] + public function supportsNormalization(): void + { + $normalizer = new UserMessageNormalizer(); + + self::assertTrue($normalizer->supportsNormalization(new UserMessage(new Text('Hello')), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not a user message')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new UserMessageNormalizer(); + + self::assertSame([UserMessage::class => true], $normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalizeTextContent(): void + { + $normalizer = new UserMessageNormalizer(); + $message = new UserMessage(new Text('Write a story about a magic backpack.')); + + $normalized = $normalizer->normalize($message); + + self::assertSame([['text' => 'Write a story about a magic backpack.']], $normalized); + } + + #[Test] + public function normalizeImageContent(): void + { + $normalizer = new UserMessageNormalizer(); + $imageContent = Image::fromFile(\dirname(__DIR__, 4).'/Fixture/image.jpg'); + $message = new UserMessage(new Text('Tell me about this instrument'), $imageContent); + + $normalized = $normalizer->normalize($message); + + self::assertCount(2, $normalized); + self::assertSame(['text' => 'Tell me about this instrument'], $normalized[0]); + self::assertArrayHasKey('inline_data', $normalized[1]); + self::assertSame('image/jpeg', $normalized[1]['inline_data']['mime_type']); + self::assertNotEmpty($normalized[1]['inline_data']['data']); + + // Verify that the base64 data string starts correctly for a JPEG + self::assertStringStartsWith('/9j/', $normalized[1]['inline_data']['data']); + } +} diff --git a/src/platform/tests/Bridge/HuggingFace/ModelClientTest.php b/src/platform/tests/Bridge/HuggingFace/ModelClientTest.php new file mode 100644 index 000000000..5f14c29d9 --- /dev/null +++ b/src/platform/tests/Bridge/HuggingFace/ModelClientTest.php @@ -0,0 +1,151 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\HuggingFace; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\HuggingFace\Contract\FileNormalizer; +use Symfony\AI\Platform\Bridge\HuggingFace\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\HuggingFace\ModelClient; +use Symfony\AI\Platform\Bridge\HuggingFace\Task; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; +use Symfony\Component\HttpClient\MockHttpClient; + +#[CoversClass(ModelClient::class)] +#[Small] +#[UsesClass(Model::class)] +final class ModelClientTest extends TestCase +{ + #[DataProvider('urlTestCases')] + public function testGetUrlForDifferentInputsAndTasks(?string $task, string $expectedUrl): void + { + $reflection = new \ReflectionClass(ModelClient::class); + $getUrlMethod = $reflection->getMethod('getUrl'); + $getUrlMethod->setAccessible(true); + + $model = new Model('test-model'); + $httpClient = new MockHttpClient(); + $modelClient = new ModelClient($httpClient, 'test-provider', 'test-api-key'); + + $actualUrl = $getUrlMethod->invoke($modelClient, $model, $task); + + self::assertEquals($expectedUrl, $actualUrl); + } + + public static function urlTestCases(): \Iterator + { + $messageBag = new MessageBag(); + $messageBag->add(new UserMessage(new Text('Test message'))); + yield 'string input' => [ + 'task' => null, + 'expectedUrl' => 'https://router.huggingface.co/test-provider/models/test-model', + ]; + yield 'array input' => [ + 'task' => null, + 'expectedUrl' => 'https://router.huggingface.co/test-provider/models/test-model', + ]; + yield 'image input' => [ + 'task' => null, + 'expectedUrl' => 'https://router.huggingface.co/test-provider/models/test-model', + ]; + yield 'feature extraction' => [ + 'task' => Task::FEATURE_EXTRACTION, + 'expectedUrl' => 'https://router.huggingface.co/test-provider/pipeline/feature-extraction/test-model', + ]; + yield 'message bag' => [ + 'task' => Task::CHAT_COMPLETION, + 'expectedUrl' => 'https://router.huggingface.co/test-provider/models/test-model/v1/chat/completions', + ]; + } + + #[DataProvider('payloadTestCases')] + public function testGetPayloadForDifferentInputsAndTasks(object|array|string $input, array $options, array $expectedKeys, array $expectedValues = []): void + { + // Contract handling first + $contract = Contract::create( + new FileNormalizer(), + new MessageBagNormalizer() + ); + + $payload = $contract->createRequestPayload(new Model('test-model'), $input); + + $reflection = new \ReflectionClass(ModelClient::class); + $getPayloadMethod = $reflection->getMethod('getPayload'); + $getPayloadMethod->setAccessible(true); + + $httpClient = new MockHttpClient(); + $modelClient = new ModelClient($httpClient, 'test-provider', 'test-api-key'); + + $actual = $getPayloadMethod->invoke($modelClient, $payload, $options); + + // Check that expected keys exist + foreach ($expectedKeys as $key) { + self::assertArrayHasKey($key, $actual); + } + + // Check expected values if specified + foreach ($expectedValues as $path => $value) { + $keys = explode('.', $path); + $current = $actual; + foreach ($keys as $key) { + self::assertArrayHasKey($key, $current); + $current = $current[$key]; + } + + self::assertEquals($value, $current); + } + } + + public static function payloadTestCases(): \Iterator + { + yield 'string input' => [ + 'input' => 'Hello world', + 'options' => [], + 'expectedKeys' => ['headers', 'json'], + 'expectedValues' => [ + 'headers.Content-Type' => 'application/json', + 'json.inputs' => 'Hello world', + ], + ]; + + yield 'array input' => [ + 'input' => ['text' => 'Hello world'], + 'options' => ['temperature' => 0.7], + 'expectedKeys' => ['headers', 'json'], + 'expectedValues' => [ + 'headers.Content-Type' => 'application/json', + 'json.inputs' => ['text' => 'Hello world'], + 'json.parameters.temperature' => 0.7, + ], + ]; + + $messageBag = new MessageBag(); + $messageBag->add(new UserMessage(new Text('Test message'))); + + yield 'message bag' => [ + 'input' => $messageBag, + 'options' => ['max_tokens' => 100], + 'expectedKeys' => ['headers', 'json'], + 'expectedValues' => [ + 'headers.Content-Type' => 'application/json', + 'json.max_tokens' => 100, + ], + ]; + } +} diff --git a/src/platform/tests/Bridge/Meta/LlamaPromptConverterTest.php b/src/platform/tests/Bridge/Meta/LlamaPromptConverterTest.php new file mode 100644 index 000000000..7539981ad --- /dev/null +++ b/src/platform/tests/Bridge/Meta/LlamaPromptConverterTest.php @@ -0,0 +1,146 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Meta; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Meta\LlamaPromptConverter; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\UserMessage; + +#[CoversClass(LlamaPromptConverter::class)] +#[Small] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(ImageUrl::class)] +#[UsesClass(Text::class)] +#[UsesClass(Message::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(UserMessage::class)] +final class LlamaPromptConverterTest extends TestCase +{ + #[Test] + public function convertMessages(): void + { + $messageBag = new MessageBag(); + foreach (self::provideMessages() as $message) { + $messageBag->add($message[1]); + } + + self::assertSame(<<<|start_header_id|>system<|end_header_id|> + + You are a helpful chatbot.<|eot_id|> + + <|start_header_id|>user<|end_header_id|> + + Hello, how are you?<|eot_id|> + + <|start_header_id|>user<|end_header_id|> + + Hello, how are you? + What is your name?<|eot_id|> + + <|start_header_id|>user<|end_header_id|> + + Hello, how are you? + What is your name? + https://example.com/image.jpg<|eot_id|> + + <|start_header_id|>assistant<|end_header_id|> + + I am an assistant.<|eot_id|> + + <|start_header_id|>assistant<|end_header_id|> + EXPECTED, + (new LlamaPromptConverter())->convertToPrompt($messageBag) + ); + } + + #[Test] + #[DataProvider('provideMessages')] + public function convertMessage(string $expected, UserMessage|SystemMessage|AssistantMessage $message): void + { + self::assertSame( + $expected, + (new LlamaPromptConverter())->convertMessage($message) + ); + } + + /** + * @return iterable + */ + public static function provideMessages(): iterable + { + yield 'System message' => [ + <<<|start_header_id|>system<|end_header_id|> + + You are a helpful chatbot.<|eot_id|> + SYSTEM, + Message::forSystem('You are a helpful chatbot.'), + ]; + + yield 'UserMessage' => [ + <<user<|end_header_id|> + + Hello, how are you?<|eot_id|> + USER, + Message::ofUser('Hello, how are you?'), + ]; + + yield 'UserMessage with two texts' => [ + <<user<|end_header_id|> + + Hello, how are you? + What is your name?<|eot_id|> + USER, + Message::ofUser('Hello, how are you?', 'What is your name?'), + ]; + + yield 'UserMessage with two texts and one image' => [ + <<user<|end_header_id|> + + Hello, how are you? + What is your name? + https://example.com/image.jpg<|eot_id|> + USER, + Message::ofUser('Hello, how are you?', 'What is your name?', new ImageUrl('https://example.com/image.jpg')), + ]; + + yield 'AssistantMessage' => [ + <<assistant<|end_header_id|> + + I am an assistant.<|eot_id|> + ASSISTANT, + new AssistantMessage('I am an assistant.'), + ]; + + yield 'AssistantMessage with null content' => [ + '', + new AssistantMessage(), + ]; + } +} diff --git a/src/platform/tests/Bridge/OpenAI/DallE/Base64ImageTest.php b/src/platform/tests/Bridge/OpenAI/DallE/Base64ImageTest.php new file mode 100644 index 000000000..615f474f4 --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/DallE/Base64ImageTest.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI\DallE; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\Base64Image; + +#[CoversClass(Base64Image::class)] +#[Small] +final class Base64ImageTest extends TestCase +{ + #[Test] + public function itCreatesBase64Image(): void + { + $emptyPixel = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; + $base64Image = new Base64Image($emptyPixel); + + self::assertSame($emptyPixel, $base64Image->encodedImage); + } + + #[Test] + public function itThrowsExceptionWhenBase64ImageIsEmpty(): void + { + self::expectException(\InvalidArgumentException::class); + self::expectExceptionMessage('The base64 encoded image generated must be given.'); + + new Base64Image(''); + } +} diff --git a/src/platform/tests/Bridge/OpenAI/DallE/ImageResponseTest.php b/src/platform/tests/Bridge/OpenAI/DallE/ImageResponseTest.php new file mode 100644 index 000000000..75e9ce84a --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/DallE/ImageResponseTest.php @@ -0,0 +1,63 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI\DallE; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\Base64Image; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\ImageResponse; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\UrlImage; + +#[CoversClass(ImageResponse::class)] +#[UsesClass(Base64Image::class)] +#[UsesClass(UrlImage::class)] +#[Small] +final class ImageResponseTest extends TestCase +{ + #[Test] + public function itCreatesImagesResponse(): void + { + $base64Image = new Base64Image('iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='); + $generatedImagesResponse = new ImageResponse(null, $base64Image); + + self::assertNull($generatedImagesResponse->revisedPrompt); + self::assertCount(1, $generatedImagesResponse->getContent()); + self::assertSame($base64Image, $generatedImagesResponse->getContent()[0]); + } + + #[Test] + public function itCreatesImagesResponseWithRevisedPrompt(): void + { + $base64Image = new Base64Image('iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='); + $generatedImagesResponse = new ImageResponse('revised prompt', $base64Image); + + self::assertSame('revised prompt', $generatedImagesResponse->revisedPrompt); + self::assertCount(1, $generatedImagesResponse->getContent()); + self::assertSame($base64Image, $generatedImagesResponse->getContent()[0]); + } + + #[Test] + public function itIsCreatableWithMultipleImages(): void + { + $image1 = new UrlImage('https://example'); + $image2 = new UrlImage('https://example2'); + + $generatedImagesResponse = new ImageResponse(null, $image1, $image2); + + self::assertCount(2, $generatedImagesResponse->getContent()); + self::assertSame($image1, $generatedImagesResponse->getContent()[0]); + self::assertSame($image2, $generatedImagesResponse->getContent()[1]); + } +} diff --git a/src/platform/tests/Bridge/OpenAI/DallE/ModelClientTest.php b/src/platform/tests/Bridge/OpenAI/DallE/ModelClientTest.php new file mode 100644 index 000000000..02e851171 --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/DallE/ModelClientTest.php @@ -0,0 +1,99 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI\DallE; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\DallE; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\Base64Image; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\ImageResponse; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\ModelClient; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\UrlImage; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\MockResponse; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +#[CoversClass(ModelClient::class)] +#[UsesClass(DallE::class)] +#[UsesClass(UrlImage::class)] +#[UsesClass(Base64Image::class)] +#[UsesClass(ImageResponse::class)] +#[Small] +final class ModelClientTest extends TestCase +{ + #[Test] + public function itIsSupportingTheCorrectModel(): void + { + $modelClient = new ModelClient(new MockHttpClient(), 'sk-api-key'); + + self::assertTrue($modelClient->supports(new DallE())); + } + + #[Test] + public function itIsExecutingTheCorrectRequest(): void + { + $responseCallback = static function (string $method, string $url, array $options): HttpResponse { + self::assertSame('POST', $method); + self::assertSame('https://api.openai.com/v1/images/generations', $url); + self::assertSame('Authorization: Bearer sk-api-key', $options['normalized_headers']['authorization'][0]); + self::assertSame('{"n":1,"response_format":"url","model":"dall-e-2","prompt":"foo"}', $options['body']); + + return new MockResponse(); + }; + $httpClient = new MockHttpClient([$responseCallback]); + $modelClient = new ModelClient($httpClient, 'sk-api-key'); + $modelClient->request(new DallE(), 'foo', ['n' => 1, 'response_format' => 'url']); + } + + #[Test] + public function itIsConvertingTheResponse(): void + { + $httpResponse = self::createStub(HttpResponse::class); + $httpResponse->method('toArray')->willReturn([ + 'data' => [ + ['url' => 'https://example.com/image.jpg'], + ], + ]); + + $modelClient = new ModelClient(new MockHttpClient(), 'sk-api-key'); + $response = $modelClient->convert($httpResponse, ['response_format' => 'url']); + + self::assertCount(1, $response->getContent()); + self::assertInstanceOf(UrlImage::class, $response->getContent()[0]); + self::assertSame('https://example.com/image.jpg', $response->getContent()[0]->url); + } + + #[Test] + public function itIsConvertingTheResponseWithRevisedPrompt(): void + { + $emptyPixel = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; + + $httpResponse = self::createStub(HttpResponse::class); + $httpResponse->method('toArray')->willReturn([ + 'data' => [ + ['b64_json' => $emptyPixel, 'revised_prompt' => 'revised prompt'], + ], + ]); + + $modelClient = new ModelClient(new MockHttpClient(), 'sk-api-key'); + $response = $modelClient->convert($httpResponse, ['response_format' => 'b64_json']); + + self::assertInstanceOf(ImageResponse::class, $response); + self::assertCount(1, $response->getContent()); + self::assertInstanceOf(Base64Image::class, $response->getContent()[0]); + self::assertSame($emptyPixel, $response->getContent()[0]->encodedImage); + self::assertSame('revised prompt', $response->revisedPrompt); + } +} diff --git a/src/platform/tests/Bridge/OpenAI/DallE/UrlImageTest.php b/src/platform/tests/Bridge/OpenAI/DallE/UrlImageTest.php new file mode 100644 index 000000000..d71d43200 --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/DallE/UrlImageTest.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI\DallE; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\DallE\UrlImage; + +#[CoversClass(UrlImage::class)] +#[Small] +final class UrlImageTest extends TestCase +{ + #[Test] + public function itCreatesUrlImage(): void + { + $urlImage = new UrlImage('https://example.com/image.jpg'); + + self::assertSame('https://example.com/image.jpg', $urlImage->url); + } + + #[Test] + public function itThrowsExceptionWhenUrlIsEmpty(): void + { + self::expectException(\InvalidArgumentException::class); + self::expectExceptionMessage('The image url must be given.'); + + new UrlImage(''); + } +} diff --git a/src/platform/tests/Bridge/OpenAI/DallETest.php b/src/platform/tests/Bridge/OpenAI/DallETest.php new file mode 100644 index 000000000..0597b20c5 --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/DallETest.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\DallE; + +#[CoversClass(DallE::class)] +#[Small] +final class DallETest extends TestCase +{ + #[Test] + public function itCreatesDallEWithDefaultSettings(): void + { + $dallE = new DallE(); + + self::assertSame(DallE::DALL_E_2, $dallE->getName()); + self::assertSame([], $dallE->getOptions()); + } + + #[Test] + public function itCreatesDallEWithCustomSettings(): void + { + $dallE = new DallE(DallE::DALL_E_3, ['response_format' => 'base64', 'n' => 2]); + + self::assertSame(DallE::DALL_E_3, $dallE->getName()); + self::assertSame(['response_format' => 'base64', 'n' => 2], $dallE->getOptions()); + } +} diff --git a/src/platform/tests/Bridge/OpenAI/Embeddings/ResponseConverterTest.php b/src/platform/tests/Bridge/OpenAI/Embeddings/ResponseConverterTest.php new file mode 100644 index 000000000..48175b06c --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/Embeddings/ResponseConverterTest.php @@ -0,0 +1,67 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI\Embeddings; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings\ResponseConverter; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Contracts\HttpClient\ResponseInterface; + +#[CoversClass(ResponseConverter::class)] +#[Small] +#[UsesClass(Vector::class)] +#[UsesClass(VectorResponse::class)] +class ResponseConverterTest extends TestCase +{ + #[Test] + public function itConvertsAResponseToAVectorResponse(): void + { + $response = $this->createStub(ResponseInterface::class); + $response + ->method('toArray') + ->willReturn(json_decode($this->getEmbeddingStub(), true)); + + $vectorResponse = (new ResponseConverter())->convert($response); + $convertedContent = $vectorResponse->getContent(); + + self::assertCount(2, $convertedContent); + + self::assertSame([0.3, 0.4, 0.4], $convertedContent[0]->getData()); + self::assertSame([0.0, 0.0, 0.2], $convertedContent[1]->getData()); + } + + private function getEmbeddingStub(): string + { + return <<<'JSON' + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [0.3, 0.4, 0.4] + }, + { + "object": "embedding", + "index": 1, + "embedding": [0.0, 0.0, 0.2] + } + ] + } + JSON; + } +} diff --git a/src/platform/tests/Bridge/OpenAI/GPT/ResponseConverterTest.php b/src/platform/tests/Bridge/OpenAI/GPT/ResponseConverterTest.php new file mode 100644 index 000000000..71478bd64 --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/GPT/ResponseConverterTest.php @@ -0,0 +1,192 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI\GPT; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\GPT\ResponseConverter; +use Symfony\AI\Platform\Exception\ContentFilterException; +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ChoiceResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; +use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +#[CoversClass(ResponseConverter::class)] +#[Small] +#[UsesClass(Choice::class)] +#[UsesClass(ChoiceResponse::class)] +#[UsesClass(TextResponse::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(ToolCallResponse::class)] +class ResponseConverterTest extends TestCase +{ + public function testConvertTextResponse(): void + { + $converter = new ResponseConverter(); + $httpResponse = self::createMock(ResponseInterface::class); + $httpResponse->method('toArray')->willReturn([ + 'choices' => [ + [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Hello world', + ], + 'finish_reason' => 'stop', + ], + ], + ]); + + $response = $converter->convert($httpResponse); + + self::assertInstanceOf(TextResponse::class, $response); + self::assertSame('Hello world', $response->getContent()); + } + + public function testConvertToolCallResponse(): void + { + $converter = new ResponseConverter(); + $httpResponse = self::createMock(ResponseInterface::class); + $httpResponse->method('toArray')->willReturn([ + 'choices' => [ + [ + 'message' => [ + 'role' => 'assistant', + 'content' => null, + 'tool_calls' => [ + [ + 'id' => 'call_123', + 'type' => 'function', + 'function' => [ + 'name' => 'test_function', + 'arguments' => '{"arg1": "value1"}', + ], + ], + ], + ], + 'finish_reason' => 'tool_calls', + ], + ], + ]); + + $response = $converter->convert($httpResponse); + + self::assertInstanceOf(ToolCallResponse::class, $response); + $toolCalls = $response->getContent(); + self::assertCount(1, $toolCalls); + self::assertSame('call_123', $toolCalls[0]->id); + self::assertSame('test_function', $toolCalls[0]->name); + self::assertSame(['arg1' => 'value1'], $toolCalls[0]->arguments); + } + + public function testConvertMultipleChoices(): void + { + $converter = new ResponseConverter(); + $httpResponse = self::createMock(ResponseInterface::class); + $httpResponse->method('toArray')->willReturn([ + 'choices' => [ + [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Choice 1', + ], + 'finish_reason' => 'stop', + ], + [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Choice 2', + ], + 'finish_reason' => 'stop', + ], + ], + ]); + + $response = $converter->convert($httpResponse); + + self::assertInstanceOf(ChoiceResponse::class, $response); + $choices = $response->getContent(); + self::assertCount(2, $choices); + self::assertSame('Choice 1', $choices[0]->getContent()); + self::assertSame('Choice 2', $choices[1]->getContent()); + } + + public function testContentFilterException(): void + { + $converter = new ResponseConverter(); + $httpResponse = self::createMock(ResponseInterface::class); + + $httpResponse->expects($this->exactly(2)) + ->method('toArray') + ->willReturnCallback(function ($throw = true) { + if ($throw) { + throw new class extends \Exception implements ClientExceptionInterface { + public function getResponse(): ResponseInterface + { + throw new RuntimeException('Not implemented'); + } + }; + } + + return [ + 'error' => [ + 'code' => 'content_filter', + 'message' => 'Content was filtered', + ], + ]; + }); + + self::expectException(ContentFilterException::class); + self::expectExceptionMessage('Content was filtered'); + + $converter->convert($httpResponse); + } + + public function testThrowsExceptionWhenNoChoices(): void + { + $converter = new ResponseConverter(); + $httpResponse = self::createMock(ResponseInterface::class); + $httpResponse->method('toArray')->willReturn([]); + + self::expectException(RuntimeException::class); + self::expectExceptionMessage('Response does not contain choices'); + + $converter->convert($httpResponse); + } + + public function testThrowsExceptionForUnsupportedFinishReason(): void + { + $converter = new ResponseConverter(); + $httpResponse = self::createMock(ResponseInterface::class); + $httpResponse->method('toArray')->willReturn([ + 'choices' => [ + [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Test content', + ], + 'finish_reason' => 'unsupported_reason', + ], + ], + ]); + + self::expectException(RuntimeException::class); + self::expectExceptionMessage('Unsupported finish reason "unsupported_reason"'); + + $converter->convert($httpResponse); + } +} diff --git a/src/platform/tests/Bridge/OpenAI/TokenOutputProcessorTest.php b/src/platform/tests/Bridge/OpenAI/TokenOutputProcessorTest.php new file mode 100644 index 000000000..f2ea7595d --- /dev/null +++ b/src/platform/tests/Bridge/OpenAI/TokenOutputProcessorTest.php @@ -0,0 +1,155 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\OpenAI; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Output; +use Symfony\AI\Platform\Bridge\OpenAI\TokenOutputProcessor; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\Metadata\Metadata; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\AI\Platform\Response\StreamResponse; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +#[CoversClass(TokenOutputProcessor::class)] +#[UsesClass(Output::class)] +#[UsesClass(TextResponse::class)] +#[UsesClass(StreamResponse::class)] +#[UsesClass(Metadata::class)] +#[Small] +final class TokenOutputProcessorTest extends TestCase +{ + #[Test] + public function itHandlesStreamResponsesWithoutProcessing(): void + { + $processor = new TokenOutputProcessor(); + $streamResponse = new StreamResponse((static function () { yield 'test'; })()); + $output = $this->createOutput($streamResponse); + + $processor->processOutput($output); + + $metadata = $output->response->getMetadata(); + self::assertCount(0, $metadata); + } + + #[Test] + public function itDoesNothingWithoutRawResponse(): void + { + $processor = new TokenOutputProcessor(); + $textResponse = new TextResponse('test'); + $output = $this->createOutput($textResponse); + + $processor->processOutput($output); + + $metadata = $output->response->getMetadata(); + self::assertCount(0, $metadata); + } + + #[Test] + public function itAddsRemainingTokensToMetadata(): void + { + $processor = new TokenOutputProcessor(); + $textResponse = new TextResponse('test'); + + $textResponse->setRawResponse($this->createRawResponse()); + + $output = $this->createOutput($textResponse); + + $processor->processOutput($output); + + $metadata = $output->response->getMetadata(); + self::assertCount(1, $metadata); + self::assertSame(1000, $metadata->get('remaining_tokens')); + } + + #[Test] + public function itAddsUsageTokensToMetadata(): void + { + $processor = new TokenOutputProcessor(); + $textResponse = new TextResponse('test'); + + $rawResponse = $this->createRawResponse([ + 'usage' => [ + 'prompt_tokens' => 10, + 'completion_tokens' => 20, + 'total_tokens' => 30, + ], + ]); + + $textResponse->setRawResponse($rawResponse); + + $output = $this->createOutput($textResponse); + + $processor->processOutput($output); + + $metadata = $output->response->getMetadata(); + self::assertCount(4, $metadata); + self::assertSame(1000, $metadata->get('remaining_tokens')); + self::assertSame(10, $metadata->get('prompt_tokens')); + self::assertSame(20, $metadata->get('completion_tokens')); + self::assertSame(30, $metadata->get('total_tokens')); + } + + #[Test] + public function itHandlesMissingUsageFields(): void + { + $processor = new TokenOutputProcessor(); + $textResponse = new TextResponse('test'); + + $rawResponse = $this->createRawResponse([ + 'usage' => [ + // Missing some fields + 'prompt_tokens' => 10, + ], + ]); + + $textResponse->setRawResponse($rawResponse); + + $output = $this->createOutput($textResponse); + + $processor->processOutput($output); + + $metadata = $output->response->getMetadata(); + self::assertCount(4, $metadata); + self::assertSame(1000, $metadata->get('remaining_tokens')); + self::assertSame(10, $metadata->get('prompt_tokens')); + self::assertNull($metadata->get('completion_tokens')); + self::assertNull($metadata->get('total_tokens')); + } + + private function createRawResponse(array $data = []): SymfonyHttpResponse + { + $rawResponse = self::createStub(SymfonyHttpResponse::class); + $rawResponse->method('getHeaders')->willReturn([ + 'x-ratelimit-remaining-tokens' => ['1000'], + ]); + $rawResponse->method('toArray')->willReturn($data); + + return $rawResponse; + } + + private function createOutput(ResponseInterface $response): Output + { + return new Output( + self::createStub(Model::class), + $response, + self::createStub(MessageBagInterface::class), + [], + ); + } +} diff --git a/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php b/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php new file mode 100644 index 000000000..4a2b69d5a --- /dev/null +++ b/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php @@ -0,0 +1,263 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\JsonSchema\Attribute; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With; +use Webmozart\Assert\InvalidArgumentException; + +#[CoversClass(With::class)] +final class ToolParameterTest extends TestCase +{ + #[Test] + public function validEnum(): void + { + $enum = ['value1', 'value2']; + $toolParameter = new With(enum: $enum); + self::assertSame($enum, $toolParameter->enum); + } + + #[Test] + public function invalidEnumContainsNonString(): void + { + self::expectException(InvalidArgumentException::class); + $enum = ['value1', 2]; + new With(enum: $enum); + } + + #[Test] + public function validConstString(): void + { + $const = 'constant value'; + $toolParameter = new With(const: $const); + self::assertSame($const, $toolParameter->const); + } + + #[Test] + public function invalidConstEmptyString(): void + { + self::expectException(InvalidArgumentException::class); + $const = ' '; + new With(const: $const); + } + + #[Test] + public function validPattern(): void + { + $pattern = '/^[a-z]+$/'; + $toolParameter = new With(pattern: $pattern); + self::assertSame($pattern, $toolParameter->pattern); + } + + #[Test] + public function invalidPatternEmptyString(): void + { + self::expectException(InvalidArgumentException::class); + $pattern = ' '; + new With(pattern: $pattern); + } + + #[Test] + public function validMinLength(): void + { + $minLength = 5; + $toolParameter = new With(minLength: $minLength); + self::assertSame($minLength, $toolParameter->minLength); + } + + #[Test] + public function invalidMinLengthNegative(): void + { + self::expectException(InvalidArgumentException::class); + new With(minLength: -1); + } + + #[Test] + public function validMinLengthAndMaxLength(): void + { + $minLength = 5; + $maxLength = 10; + $toolParameter = new With(minLength: $minLength, maxLength: $maxLength); + self::assertSame($minLength, $toolParameter->minLength); + self::assertSame($maxLength, $toolParameter->maxLength); + } + + #[Test] + public function invalidMaxLengthLessThanMinLength(): void + { + self::expectException(InvalidArgumentException::class); + new With(minLength: 10, maxLength: 5); + } + + #[Test] + public function validMinimum(): void + { + $minimum = 0; + $toolParameter = new With(minimum: $minimum); + self::assertSame($minimum, $toolParameter->minimum); + } + + #[Test] + public function invalidMinimumNegative(): void + { + self::expectException(InvalidArgumentException::class); + new With(minimum: -1); + } + + #[Test] + public function validMultipleOf(): void + { + $multipleOf = 5; + $toolParameter = new With(multipleOf: $multipleOf); + self::assertSame($multipleOf, $toolParameter->multipleOf); + } + + #[Test] + public function invalidMultipleOfNegative(): void + { + self::expectException(InvalidArgumentException::class); + new With(multipleOf: -5); + } + + #[Test] + public function validExclusiveMinimumAndMaximum(): void + { + $exclusiveMinimum = 1; + $exclusiveMaximum = 10; + $toolParameter = new With(exclusiveMinimum: $exclusiveMinimum, exclusiveMaximum: $exclusiveMaximum); + self::assertSame($exclusiveMinimum, $toolParameter->exclusiveMinimum); + self::assertSame($exclusiveMaximum, $toolParameter->exclusiveMaximum); + } + + #[Test] + public function invalidExclusiveMaximumLessThanExclusiveMinimum(): void + { + self::expectException(InvalidArgumentException::class); + new With(exclusiveMinimum: 10, exclusiveMaximum: 5); + } + + #[Test] + public function validMinItemsAndMaxItems(): void + { + $minItems = 1; + $maxItems = 5; + $toolParameter = new With(minItems: $minItems, maxItems: $maxItems); + self::assertSame($minItems, $toolParameter->minItems); + self::assertSame($maxItems, $toolParameter->maxItems); + } + + #[Test] + public function invalidMaxItemsLessThanMinItems(): void + { + self::expectException(InvalidArgumentException::class); + new With(minItems: 5, maxItems: 1); + } + + #[Test] + public function validUniqueItemsTrue(): void + { + $toolParameter = new With(uniqueItems: true); + self::assertTrue($toolParameter->uniqueItems); + } + + #[Test] + public function invalidUniqueItemsFalse(): void + { + self::expectException(InvalidArgumentException::class); + new With(uniqueItems: false); + } + + #[Test] + public function validMinContainsAndMaxContains(): void + { + $minContains = 1; + $maxContains = 3; + $toolParameter = new With(minContains: $minContains, maxContains: $maxContains); + self::assertSame($minContains, $toolParameter->minContains); + self::assertSame($maxContains, $toolParameter->maxContains); + } + + #[Test] + public function invalidMaxContainsLessThanMinContains(): void + { + self::expectException(InvalidArgumentException::class); + new With(minContains: 3, maxContains: 1); + } + + #[Test] + public function validRequired(): void + { + $toolParameter = new With(required: true); + self::assertTrue($toolParameter->required); + } + + #[Test] + public function validMinPropertiesAndMaxProperties(): void + { + $minProperties = 1; + $maxProperties = 5; + $toolParameter = new With(minProperties: $minProperties, maxProperties: $maxProperties); + self::assertSame($minProperties, $toolParameter->minProperties); + self::assertSame($maxProperties, $toolParameter->maxProperties); + } + + #[Test] + public function invalidMaxPropertiesLessThanMinProperties(): void + { + self::expectException(InvalidArgumentException::class); + new With(minProperties: 5, maxProperties: 1); + } + + #[Test] + public function validDependentRequired(): void + { + $toolParameter = new With(dependentRequired: true); + self::assertTrue($toolParameter->dependentRequired); + } + + #[Test] + public function validCombination(): void + { + $toolParameter = new With( + enum: ['value1', 'value2'], + const: 'constant', + pattern: '/^[a-z]+$/', + minLength: 5, + maxLength: 10, + minimum: 0, + maximum: 100, + multipleOf: 5, + exclusiveMinimum: 1, + exclusiveMaximum: 99, + minItems: 1, + maxItems: 10, + uniqueItems: true, + minContains: 1, + maxContains: 5, + required: true, + minProperties: 1, + maxProperties: 5, + dependentRequired: true + ); + + self::assertInstanceOf(With::class, $toolParameter); + } + + #[Test] + public function invalidCombination(): void + { + self::expectException(InvalidArgumentException::class); + new With(minLength: -1, maxLength: -2); + } +} diff --git a/src/platform/tests/Contract/JsonSchema/DescriptionParserTest.php b/src/platform/tests/Contract/JsonSchema/DescriptionParserTest.php new file mode 100644 index 000000000..fb61ee2ed --- /dev/null +++ b/src/platform/tests/Contract/JsonSchema/DescriptionParserTest.php @@ -0,0 +1,133 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\JsonSchema; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\User; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\UserWithConstructor; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolWithoutDocs; +use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; + +#[CoversClass(DescriptionParser::class)] +final class DescriptionParserTest extends TestCase +{ + #[Test] + public function fromPropertyWithoutDocBlock(): void + { + $property = new \ReflectionProperty(User::class, 'id'); + + $actual = (new DescriptionParser())->getDescription($property); + + self::assertSame('', $actual); + } + + #[Test] + public function fromPropertyWithDocBlock(): void + { + $property = new \ReflectionProperty(User::class, 'name'); + + $actual = (new DescriptionParser())->getDescription($property); + + self::assertSame('The name of the user in lowercase', $actual); + } + + #[Test] + public function fromPropertyWithConstructorDocBlock(): void + { + $property = new \ReflectionProperty(UserWithConstructor::class, 'name'); + + $actual = (new DescriptionParser())->getDescription($property); + + self::assertSame('The name of the user in lowercase', $actual); + } + + #[Test] + public function fromParameterWithoutDocBlock(): void + { + $parameter = new \ReflectionParameter([ToolWithoutDocs::class, 'bar'], 'text'); + + $actual = (new DescriptionParser())->getDescription($parameter); + + self::assertSame('', $actual); + } + + #[Test] + public function fromParameterWithDocBlock(): void + { + $parameter = new \ReflectionParameter([ToolRequiredParams::class, 'bar'], 'text'); + + $actual = (new DescriptionParser())->getDescription($parameter); + + self::assertSame('The text given to the tool', $actual); + } + + #[Test] + #[DataProvider('provideMethodDescriptionCases')] + public function fromParameterWithDocs(string $comment, string $expected): void + { + $method = self::createMock(\ReflectionMethod::class); + $method->method('getDocComment')->willReturn($comment); + $parameter = self::createMock(\ReflectionParameter::class); + $parameter->method('getDeclaringFunction')->willReturn($method); + $parameter->method('getName')->willReturn('myParam'); + + $actual = (new DescriptionParser())->getDescription($parameter); + + self::assertSame($expected, $actual); + } + + public static function provideMethodDescriptionCases(): \Generator + { + yield 'empty doc block' => [ + 'comment' => '', + 'expected' => '', + ]; + + yield 'single line doc block with description' => [ + 'comment' => '/** @param string $myParam The description */', + 'expected' => 'The description', + ]; + + yield 'multi line doc block with description and other tags' => [ + 'comment' => <<<'TEXT' + /** + * @param string $myParam The description + * @return void + */ + TEXT, + 'expected' => 'The description', + ]; + + yield 'multi line doc block with multiple parameters' => [ + 'comment' => <<<'TEXT' + /** + * @param string $myParam The description + * @param string $anotherParam The wrong description + */ + TEXT, + 'expected' => 'The description', + ]; + + yield 'multi line doc block with parameter that is not searched for' => [ + 'comment' => <<<'TEXT' + /** + * @param string $unknownParam The description + */ + TEXT, + 'expected' => '', + ]; + } +} diff --git a/src/platform/tests/Contract/JsonSchema/FactoryTest.php b/src/platform/tests/Contract/JsonSchema/FactoryTest.php new file mode 100644 index 000000000..d86621fbd --- /dev/null +++ b/src/platform/tests/Contract/JsonSchema/FactoryTest.php @@ -0,0 +1,251 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\JsonSchema; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\MathReasoning; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\Step; +use Symfony\AI\Agent\Tests\Fixture\StructuredOutput\User; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolOptionalParam; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolWithToolParameterAttribute; +use Symfony\AI\Platform\Contract\JsonSchema\Attribute\With; +use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; + +#[CoversClass(Factory::class)] +#[UsesClass(With::class)] +#[UsesClass(DescriptionParser::class)] +final class FactoryTest extends TestCase +{ + private Factory $factory; + + protected function setUp(): void + { + $this->factory = new Factory(); + } + + protected function tearDown(): void + { + unset($this->factory); + } + + #[Test] + public function buildParametersDefinitionRequired(): void + { + $actual = $this->factory->buildParameters(ToolRequiredParams::class, 'bar'); + $expected = [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ]; + + self::assertSame($expected, $actual); + } + + #[Test] + public function buildParametersDefinitionRequiredWithAdditionalToolParameterAttribute(): void + { + $actual = $this->factory->buildParameters(ToolWithToolParameterAttribute::class, '__invoke'); + $expected = [ + 'type' => 'object', + 'properties' => [ + 'animal' => [ + 'type' => 'string', + 'description' => 'The animal given to the tool', + 'enum' => ['dog', 'cat', 'bird'], + ], + 'numberOfArticles' => [ + 'type' => 'integer', + 'description' => 'The number of articles given to the tool', + 'const' => 42, + ], + 'infoEmail' => [ + 'type' => 'string', + 'description' => 'The info email given to the tool', + 'const' => 'info@example.de', + ], + 'locales' => [ + 'type' => 'string', + 'description' => 'The locales given to the tool', + 'const' => ['de', 'en'], + ], + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + 'pattern' => '^[a-zA-Z]+$', + 'minLength' => 1, + 'maxLength' => 10, + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'The number given to the tool', + 'minimum' => 1, + 'maximum' => 10, + 'multipleOf' => 2, + 'exclusiveMinimum' => 1, + 'exclusiveMaximum' => 10, + ], + 'products' => [ + 'type' => 'array', + 'items' => ['type' => 'string'], + 'description' => 'The products given to the tool', + 'minItems' => 1, + 'maxItems' => 10, + 'uniqueItems' => true, + 'minContains' => 1, + 'maxContains' => 10, + ], + 'shippingAddress' => [ + 'type' => 'string', + 'description' => 'The shipping address given to the tool', + 'required' => true, + 'minProperties' => 1, + 'maxProperties' => 10, + 'dependentRequired' => true, + ], + ], + 'required' => [ + 'animal', + 'numberOfArticles', + 'infoEmail', + 'locales', + 'text', + 'number', + 'products', + 'shippingAddress', + ], + 'additionalProperties' => false, + ]; + + self::assertSame($expected, $actual); + } + + #[Test] + public function buildParametersDefinitionOptional(): void + { + $actual = $this->factory->buildParameters(ToolOptionalParam::class, 'bar'); + $expected = [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text'], + 'additionalProperties' => false, + ]; + + self::assertSame($expected, $actual); + } + + #[Test] + public function buildParametersDefinitionNone(): void + { + $actual = $this->factory->buildParameters(ToolNoParams::class, '__invoke'); + + self::assertNull($actual); + } + + #[Test] + public function buildPropertiesForUserClass(): void + { + $expected = [ + 'type' => 'object', + 'properties' => [ + 'id' => ['type' => 'integer'], + 'name' => [ + 'type' => 'string', + 'description' => 'The name of the user in lowercase', + ], + 'createdAt' => [ + 'type' => 'string', + 'format' => 'date-time', + ], + 'isActive' => ['type' => 'boolean'], + 'age' => ['type' => ['integer', 'null']], + ], + 'required' => ['id', 'name', 'createdAt', 'isActive'], + 'additionalProperties' => false, + ]; + + $actual = $this->factory->buildProperties(User::class); + + self::assertSame($expected, $actual); + } + + #[Test] + public function buildPropertiesForMathReasoningClass(): void + { + $expected = [ + 'type' => 'object', + 'properties' => [ + 'steps' => [ + 'type' => 'array', + 'items' => [ + 'type' => 'object', + 'properties' => [ + 'explanation' => ['type' => 'string'], + 'output' => ['type' => 'string'], + ], + 'required' => ['explanation', 'output'], + 'additionalProperties' => false, + ], + ], + 'finalAnswer' => ['type' => 'string'], + ], + 'required' => ['steps', 'finalAnswer'], + 'additionalProperties' => false, + ]; + + $actual = $this->factory->buildProperties(MathReasoning::class); + + self::assertSame($expected, $actual); + } + + #[Test] + public function buildPropertiesForStepClass(): void + { + $expected = [ + 'type' => 'object', + 'properties' => [ + 'explanation' => ['type' => 'string'], + 'output' => ['type' => 'string'], + ], + 'required' => ['explanation', 'output'], + 'additionalProperties' => false, + ]; + + $actual = $this->factory->buildProperties(Step::class); + + self::assertSame($expected, $actual); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/AssistantMessageNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/AssistantMessageNormalizerTest.php new file mode 100644 index 000000000..226302676 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/AssistantMessageNormalizerTest.php @@ -0,0 +1,115 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\AssistantMessageNormalizer; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +#[CoversClass(AssistantMessageNormalizer::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(ToolCall::class)] +final class AssistantMessageNormalizerTest extends TestCase +{ + private AssistantMessageNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new AssistantMessageNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(new AssistantMessage('content'))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([AssistantMessage::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalizeWithContent(): void + { + $message = new AssistantMessage('I am an assistant'); + + $expected = [ + 'role' => 'assistant', + 'content' => 'I am an assistant', + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } + + #[Test] + public function normalizeWithToolCalls(): void + { + $toolCalls = [ + new ToolCall('id1', 'function1', ['param' => 'value']), + new ToolCall('id2', 'function2', ['param' => 'value2']), + ]; + $message = new AssistantMessage('Content with tools', $toolCalls); + + $expectedToolCalls = [ + ['id' => 'id1', 'function' => 'function1', 'arguments' => ['param' => 'value']], + ['id' => 'id2', 'function' => 'function2', 'arguments' => ['param' => 'value2']], + ]; + + $innerNormalizer = $this->createMock(NormalizerInterface::class); + $innerNormalizer->expects(self::once()) + ->method('normalize') + ->with($message->toolCalls, null, []) + ->willReturn($expectedToolCalls); + + $this->normalizer->setNormalizer($innerNormalizer); + + $expected = [ + 'role' => 'assistant', + 'content' => 'Content with tools', + 'tool_calls' => $expectedToolCalls, + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } + + #[Test] + public function normalizeWithNullContent(): void + { + $toolCalls = [new ToolCall('id1', 'function1', ['param' => 'value'])]; + $message = new AssistantMessage(null, $toolCalls); + + $expectedToolCalls = [['id' => 'id1', 'function' => 'function1', 'arguments' => ['param' => 'value']]]; + + $innerNormalizer = $this->createMock(NormalizerInterface::class); + $innerNormalizer->expects(self::once()) + ->method('normalize') + ->with($message->toolCalls, null, []) + ->willReturn($expectedToolCalls); + + $this->normalizer->setNormalizer($innerNormalizer); + + $expected = [ + 'role' => 'assistant', + 'tool_calls' => $expectedToolCalls, + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/Content/AudioNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/Content/AudioNormalizerTest.php new file mode 100644 index 000000000..35c60630f --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/Content/AudioNormalizerTest.php @@ -0,0 +1,83 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\AudioNormalizer; +use Symfony\AI\Platform\Message\Content\Audio; +use Symfony\AI\Platform\Message\Content\File; + +#[CoversClass(AudioNormalizer::class)] +#[UsesClass(Audio::class)] +#[UsesClass(File::class)] +final class AudioNormalizerTest extends TestCase +{ + private AudioNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new AudioNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(Audio::fromFile(\dirname(__DIR__, 5).'/Fixture/audio.mp3'))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([Audio::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + #[DataProvider('provideAudioData')] + public function normalize(string $data, string $format, array $expected): void + { + $audio = new Audio(base64_decode($data), $format); + + self::assertSame($expected, $this->normalizer->normalize($audio)); + } + + public static function provideAudioData(): \Generator + { + yield 'mp3 data' => [ + 'SUQzBAAAAAAAfVREUkMAAAAMAAADMg==', + 'audio/mpeg', + [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => 'SUQzBAAAAAAAfVREUkMAAAAMAAADMg==', + 'format' => 'mp3', + ], + ], + ]; + + yield 'wav data' => [ + 'UklGRiQAAABXQVZFZm10IBA=', + 'audio/wav', + [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => 'UklGRiQAAABXQVZFZm10IBA=', + 'format' => 'wav', + ], + ], + ]; + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/Content/ImageNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/Content/ImageNormalizerTest.php new file mode 100644 index 000000000..5bbcdd4e0 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/Content/ImageNormalizerTest.php @@ -0,0 +1,59 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\ImageNormalizer; +use Symfony\AI\Platform\Message\Content\File; +use Symfony\AI\Platform\Message\Content\Image; + +#[CoversClass(ImageNormalizer::class)] +#[UsesClass(Image::class)] +#[UsesClass(File::class)] +final class ImageNormalizerTest extends TestCase +{ + private ImageNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new ImageNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(Image::fromFile(\dirname(__DIR__, 5).'/Fixture/image.jpg'))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([Image::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalize(): void + { + $image = Image::fromDataUrl(''); + + $expected = [ + 'type' => 'image_url', + 'image_url' => ['url' => ''], + ]; + + self::assertSame($expected, $this->normalizer->normalize($image)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/Content/ImageUrlNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/Content/ImageUrlNormalizerTest.php new file mode 100644 index 000000000..9e2d29b2f --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/Content/ImageUrlNormalizerTest.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\ImageUrlNormalizer; +use Symfony\AI\Platform\Message\Content\ImageUrl; + +#[CoversClass(ImageUrlNormalizer::class)] +#[UsesClass(ImageUrl::class)] +final class ImageUrlNormalizerTest extends TestCase +{ + private ImageUrlNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new ImageUrlNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(new ImageUrl('https://example.com/image.jpg'))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([ImageUrl::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalize(): void + { + $imageUrl = new ImageUrl('https://example.com/image.jpg'); + + $expected = [ + 'type' => 'image_url', + 'image_url' => ['url' => 'https://example.com/image.jpg'], + ]; + + self::assertSame($expected, $this->normalizer->normalize($imageUrl)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/Content/TextNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/Content/TextNormalizerTest.php new file mode 100644 index 000000000..3859ddf72 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/Content/TextNormalizerTest.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\TextNormalizer; +use Symfony\AI\Platform\Message\Content\Text; + +#[CoversClass(TextNormalizer::class)] +#[UsesClass(Text::class)] +final class TextNormalizerTest extends TestCase +{ + private TextNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new TextNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(new Text('Hello, world!'))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([Text::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalize(): void + { + $text = new Text('Hello, world!'); + + $expected = [ + 'type' => 'text', + 'text' => 'Hello, world!', + ]; + + self::assertSame($expected, $this->normalizer->normalize($text)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/MessageBagNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/MessageBagNormalizerTest.php new file mode 100644 index 000000000..ad79a7d18 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/MessageBagNormalizerTest.php @@ -0,0 +1,124 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Contract\Normalizer\Message\MessageBagNormalizer; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\MessageBagInterface; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +#[CoversClass(MessageBagNormalizer::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(Text::class)] +#[UsesClass(GPT::class)] +#[UsesClass(Model::class)] +final class MessageBagNormalizerTest extends TestCase +{ + private MessageBagNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new MessageBagNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + $messageBag = $this->createMock(MessageBagInterface::class); + + self::assertTrue($this->normalizer->supportsNormalization($messageBag)); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([MessageBagInterface::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalizeWithoutModel(): void + { + $messages = [ + new SystemMessage('You are a helpful assistant'), + new UserMessage(new Text('Hello')), + ]; + + $messageBag = new MessageBag(...$messages); + + $innerNormalizer = $this->createMock(NormalizerInterface::class); + $innerNormalizer->expects(self::once()) + ->method('normalize') + ->with($messages, null, []) + ->willReturn([ + ['role' => 'system', 'content' => 'You are a helpful assistant'], + ['role' => 'user', 'content' => 'Hello'], + ]); + + $this->normalizer->setNormalizer($innerNormalizer); + + $expected = [ + 'messages' => [ + ['role' => 'system', 'content' => 'You are a helpful assistant'], + ['role' => 'user', 'content' => 'Hello'], + ], + ]; + + self::assertSame($expected, $this->normalizer->normalize($messageBag)); + } + + #[Test] + public function normalizeWithModel(): void + { + $messages = [ + new SystemMessage('You are a helpful assistant'), + new UserMessage(new Text('Hello')), + ]; + + $messageBag = new MessageBag(...$messages); + + $innerNormalizer = $this->createMock(NormalizerInterface::class); + $innerNormalizer->expects(self::once()) + ->method('normalize') + ->with($messages, null, [Contract::CONTEXT_MODEL => new GPT()]) + ->willReturn([ + ['role' => 'system', 'content' => 'You are a helpful assistant'], + ['role' => 'user', 'content' => 'Hello'], + ]); + + $this->normalizer->setNormalizer($innerNormalizer); + + $expected = [ + 'messages' => [ + ['role' => 'system', 'content' => 'You are a helpful assistant'], + ['role' => 'user', 'content' => 'Hello'], + ], + 'model' => 'gpt-4o', + ]; + + self::assertSame($expected, $this->normalizer->normalize($messageBag, context: [ + Contract::CONTEXT_MODEL => new GPT(), + ])); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/SystemMessageNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/SystemMessageNormalizerTest.php new file mode 100644 index 000000000..a83e37166 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/SystemMessageNormalizerTest.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\SystemMessageNormalizer; +use Symfony\AI\Platform\Message\SystemMessage; + +#[CoversClass(SystemMessageNormalizer::class)] +#[UsesClass(SystemMessage::class)] +final class SystemMessageNormalizerTest extends TestCase +{ + private SystemMessageNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new SystemMessageNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(new SystemMessage('content'))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([SystemMessage::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalize(): void + { + $message = new SystemMessage('You are a helpful assistant'); + + $expected = [ + 'role' => 'system', + 'content' => 'You are a helpful assistant', + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/ToolCallMessageNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/ToolCallMessageNormalizerTest.php new file mode 100644 index 000000000..b14b0add6 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/ToolCallMessageNormalizerTest.php @@ -0,0 +1,73 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +#[CoversClass(ToolCallMessageNormalizer::class)] +#[UsesClass(ToolCallMessage::class)] +#[UsesClass(ToolCall::class)] +final class ToolCallMessageNormalizerTest extends TestCase +{ + private ToolCallMessageNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new ToolCallMessageNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + $toolCallMessage = new ToolCallMessage(new ToolCall('id', 'function'), 'content'); + + self::assertTrue($this->normalizer->supportsNormalization($toolCallMessage)); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([ToolCallMessage::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalize(): void + { + $toolCall = new ToolCall('tool_call_123', 'get_weather', ['location' => 'Paris']); + $message = new ToolCallMessage($toolCall, 'Weather data for Paris'); + $expectedContent = 'Normalized weather data for Paris'; + + $innerNormalizer = $this->createMock(NormalizerInterface::class); + $innerNormalizer->expects(self::once()) + ->method('normalize') + ->with($message->content, null, []) + ->willReturn($expectedContent); + + $this->normalizer->setNormalizer($innerNormalizer); + + $expected = [ + 'role' => 'tool', + 'content' => $expectedContent, + 'tool_call_id' => 'tool_call_123', + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/Message/UserMessageNormalizerTest.php b/src/platform/tests/Contract/Normalizer/Message/UserMessageNormalizerTest.php new file mode 100644 index 000000000..6eda0b0bb --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/Message/UserMessageNormalizerTest.php @@ -0,0 +1,91 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Contract\Normalizer\Message\UserMessageNormalizer; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\Component\Serializer\Normalizer\NormalizerInterface; + +#[CoversClass(UserMessageNormalizer::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(Text::class)] +#[UsesClass(ImageUrl::class)] +final class UserMessageNormalizerTest extends TestCase +{ + private UserMessageNormalizer $normalizer; + + protected function setUp(): void + { + $this->normalizer = new UserMessageNormalizer(); + } + + #[Test] + public function supportsNormalization(): void + { + self::assertTrue($this->normalizer->supportsNormalization(new UserMessage(new Text('content')))); + self::assertFalse($this->normalizer->supportsNormalization(new \stdClass())); + } + + #[Test] + public function getSupportedTypes(): void + { + self::assertSame([UserMessage::class => true], $this->normalizer->getSupportedTypes(null)); + } + + #[Test] + public function normalizeWithSingleTextContent(): void + { + $textContent = new Text('Hello, how can you help me?'); + $message = new UserMessage($textContent); + + $expected = [ + 'role' => 'user', + 'content' => 'Hello, how can you help me?', + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } + + #[Test] + public function normalizeWithMixedContent(): void + { + $textContent = new Text('Please describe this image:'); + $imageContent = new ImageUrl('https://example.com/image.jpg'); + $message = new UserMessage($textContent, $imageContent); + + $expectedContent = [ + ['type' => 'text', 'text' => 'Please describe this image:'], + ['type' => 'image', 'url' => 'https://example.com/image.jpg'], + ]; + + $innerNormalizer = $this->createMock(NormalizerInterface::class); + $innerNormalizer->expects(self::once()) + ->method('normalize') + ->with($message->content, null, []) + ->willReturn($expectedContent); + + $this->normalizer->setNormalizer($innerNormalizer); + + $expected = [ + 'role' => 'user', + 'content' => $expectedContent, + ]; + + self::assertSame($expected, $this->normalizer->normalize($message)); + } +} diff --git a/src/platform/tests/Contract/Normalizer/ToolNormalizerTest.php b/src/platform/tests/Contract/Normalizer/ToolNormalizerTest.php new file mode 100644 index 000000000..b08192c42 --- /dev/null +++ b/src/platform/tests/Contract/Normalizer/ToolNormalizerTest.php @@ -0,0 +1,160 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Contract\Normalizer; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolException; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolNoParams; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolOptionalParam; +use Symfony\AI\Agent\Tests\Fixture\Tool\ToolRequiredParams; +use Symfony\AI\Platform\Contract\Normalizer\ToolNormalizer; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[CoversClass(ToolNormalizer::class)] +#[Small] +class ToolNormalizerTest extends TestCase +{ + #[Test] + #[DataProvider('provideTools')] + public function normalize(Tool $tool, array $expected): void + { + self::assertSame($expected, (new ToolNormalizer())->normalize($tool)); + } + + public static function provideTools(): \Generator + { + yield 'required params' => [ + new Tool( + new ExecutionReference(ToolRequiredParams::class, 'bar'), + 'tool_required_params', + 'A tool with required parameters', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ), + [ + 'type' => 'function', + 'function' => [ + 'name' => 'tool_required_params', + 'description' => 'A tool with required parameters', + 'parameters' => [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ], + ], + ]; + + yield 'optional param' => [ + new Tool( + new ExecutionReference(ToolOptionalParam::class, 'bar'), + 'tool_optional_param', + 'A tool with one optional parameter', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text'], + 'additionalProperties' => false, + ], + ), + [ + 'type' => 'function', + 'function' => [ + 'name' => 'tool_optional_param', + 'description' => 'A tool with one optional parameter', + 'parameters' => [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'The text given to the tool', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'A number given to the tool', + ], + ], + 'required' => ['text'], + 'additionalProperties' => false, + ], + ], + ], + ]; + + yield 'no params' => [ + new Tool( + new ExecutionReference(ToolNoParams::class), + 'tool_no_params', + 'A tool without parameters', + ), + [ + 'type' => 'function', + 'function' => [ + 'name' => 'tool_no_params', + 'description' => 'A tool without parameters', + ], + ], + ]; + + yield 'exception' => [ + new Tool( + new ExecutionReference(ToolException::class, 'bar'), + 'tool_exception', + 'This tool is broken', + ), + [ + 'type' => 'function', + 'function' => [ + 'name' => 'tool_exception', + 'description' => 'This tool is broken', + ], + ], + ]; + } +} diff --git a/src/platform/tests/ContractTest.php b/src/platform/tests/ContractTest.php new file mode 100644 index 000000000..4d237e0db --- /dev/null +++ b/src/platform/tests/ContractTest.php @@ -0,0 +1,217 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Large; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Bridge\OpenAI\GPT; +use Symfony\AI\Platform\Bridge\OpenAI\Whisper; +use Symfony\AI\Platform\Bridge\OpenAI\Whisper\AudioNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Contract\Normalizer\Message\AssistantMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\ImageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\ImageUrlNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\Content\TextNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\MessageBagNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\SystemMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Message\UserMessageNormalizer; +use Symfony\AI\Platform\Contract\Normalizer\Response\ToolCallNormalizer; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Content\Audio; +use Symfony\AI\Platform\Message\Content\Image; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Model; + +#[Large] +#[CoversClass(Contract::class)] +#[CoversClass(AssistantMessageNormalizer::class)] +#[CoversClass(AudioNormalizer::class)] +#[CoversClass(ImageNormalizer::class)] +#[CoversClass(ImageUrlNormalizer::class)] +#[CoversClass(TextNormalizer::class)] +#[CoversClass(MessageBagNormalizer::class)] +#[CoversClass(SystemMessageNormalizer::class)] +#[CoversClass(ToolCallMessageNormalizer::class)] +#[CoversClass(UserMessageNormalizer::class)] +#[CoversClass(ToolCallNormalizer::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(MessageBag::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(Model::class)] +final class ContractTest extends TestCase +{ + #[Test] + #[DataProvider('providePayloadTestCases')] + public function createRequestPayload(Model $model, array|string|object $input, array|string $expected): void + { + $contract = Contract::create(); + + $actual = $contract->createRequestPayload($model, $input); + + self::assertSame($expected, $actual); + } + + /** + * @return iterable|string + * }> + */ + public static function providePayloadTestCases(): iterable + { + yield 'MessageBag with GPT' => [ + 'model' => new GPT(), + 'input' => new MessageBag( + Message::forSystem('System message'), + Message::ofUser('User message'), + Message::ofAssistant('Assistant message'), + ), + 'expected' => [ + 'messages' => [ + ['role' => 'system', 'content' => 'System message'], + ['role' => 'user', 'content' => 'User message'], + ['role' => 'assistant', 'content' => 'Assistant message'], + ], + 'model' => 'gpt-4o', + ], + ]; + + $audio = Audio::fromFile(\dirname(__DIR__, 2).'/tests/Fixture/audio.mp3'); + yield 'Audio within MessageBag with GPT' => [ + 'model' => new GPT(), + 'input' => new MessageBag(Message::ofUser('What is this recording about?', $audio)), + 'expected' => [ + 'messages' => [ + [ + 'role' => 'user', + 'content' => [ + ['type' => 'text', 'text' => 'What is this recording about?'], + [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => $audio->asBase64(), + 'format' => 'mp3', + ], + ], + ], + ], + ], + 'model' => 'gpt-4o', + ], + ]; + + $image = Image::fromFile(\dirname(__DIR__, 2).'/tests/Fixture/image.jpg'); + yield 'Image within MessageBag with GPT' => [ + 'model' => new GPT(), + 'input' => new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser('Describe the image as a comedian would do it.', $image), + ), + 'expected' => [ + 'messages' => [ + [ + 'role' => 'system', + 'content' => 'You are an image analyzer bot that helps identify the content of images.', + ], + [ + 'role' => 'user', + 'content' => [ + ['type' => 'text', 'text' => 'Describe the image as a comedian would do it.'], + ['type' => 'image_url', 'image_url' => ['url' => $image->asDataUrl()]], + ], + ], + ], + 'model' => 'gpt-4o', + ], + ]; + + yield 'ImageUrl within MessageBag with GPT' => [ + 'model' => new GPT(), + 'input' => new MessageBag( + Message::forSystem('You are an image analyzer bot that helps identify the content of images.'), + Message::ofUser('Describe the image as a comedian would do it.', new ImageUrl('https://example.com/image.jpg')), + ), + 'expected' => [ + 'messages' => [ + [ + 'role' => 'system', + 'content' => 'You are an image analyzer bot that helps identify the content of images.', + ], + [ + 'role' => 'user', + 'content' => [ + ['type' => 'text', 'text' => 'Describe the image as a comedian would do it.'], + ['type' => 'image_url', 'image_url' => ['url' => 'https://example.com/image.jpg']], + ], + ], + ], + 'model' => 'gpt-4o', + ], + ]; + + yield 'Text Input with Embeddings' => [ + 'model' => new Embeddings(), + 'input' => 'This is a test input.', + 'expected' => 'This is a test input.', + ]; + + yield 'Longer Conversation with GPT' => [ + 'model' => new GPT(), + 'input' => new MessageBag( + Message::forSystem('My amazing system prompt.'), + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + new AssistantMessage('Hello User!'), + Message::ofUser('My hint for how to analyze an image.', new ImageUrl('http://image-generator.local/my-fancy-image.png')), + ), + 'expected' => [ + 'messages' => [ + ['role' => 'system', 'content' => 'My amazing system prompt.'], + ['role' => 'assistant', 'content' => 'It is time to sleep.'], + ['role' => 'user', 'content' => 'Hello, world!'], + ['role' => 'assistant', 'content' => 'Hello User!'], + ['role' => 'user', 'content' => [ + ['type' => 'text', 'text' => 'My hint for how to analyze an image.'], + ['type' => 'image_url', 'image_url' => ['url' => 'http://image-generator.local/my-fancy-image.png']], + ]], + ], + 'model' => 'gpt-4o', + ], + ]; + } + + #[Test] + public function extendedContractHandlesWhisper(): void + { + $contract = Contract::create(new AudioNormalizer()); + + $audio = Audio::fromFile(\dirname(__DIR__, 2).'/tests/Fixture/audio.mp3'); + + $actual = $contract->createRequestPayload(new Whisper(), $audio); + + self::assertArrayHasKey('model', $actual); + self::assertSame('whisper-1', $actual['model']); + self::assertArrayHasKey('file', $actual); + self::assertTrue(\is_resource($actual['file'])); + } +} diff --git a/src/platform/tests/Message/AssistantMessageTest.php b/src/platform/tests/Message/AssistantMessageTest.php new file mode 100644 index 000000000..130c00f9f --- /dev/null +++ b/src/platform/tests/Message/AssistantMessageTest.php @@ -0,0 +1,53 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Role; +use Symfony\AI\Platform\Response\ToolCall; + +#[CoversClass(AssistantMessage::class)] +#[UsesClass(ToolCall::class)] +#[Small] +final class AssistantMessageTest extends TestCase +{ + #[Test] + public function theRoleOfTheMessageIsAsExpected(): void + { + self::assertSame(Role::Assistant, (new AssistantMessage())->getRole()); + } + + #[Test] + public function constructionWithoutToolCallIsPossible(): void + { + $message = new AssistantMessage('foo'); + + self::assertSame('foo', $message->content); + self::assertNull($message->toolCalls); + } + + #[Test] + public function constructionWithoutContentIsPossible(): void + { + $toolCall = new ToolCall('foo', 'foo'); + $message = new AssistantMessage(toolCalls: [$toolCall]); + + self::assertNull($message->content); + self::assertSame([$toolCall], $message->toolCalls); + self::assertTrue($message->hasToolCalls()); + } +} diff --git a/src/platform/tests/Message/Content/AudioTest.php b/src/platform/tests/Message/Content/AudioTest.php new file mode 100644 index 000000000..04365694b --- /dev/null +++ b/src/platform/tests/Message/Content/AudioTest.php @@ -0,0 +1,69 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Content\Audio; + +#[CoversClass(Audio::class)] +#[Small] +final class AudioTest extends TestCase +{ + #[Test] + public function constructWithValidData(): void + { + $audio = new Audio('somedata', 'audio/mpeg'); + + self::assertSame('somedata', $audio->asBinary()); + self::assertSame('audio/mpeg', $audio->getFormat()); + } + + #[Test] + public function fromDataUrlWithValidUrl(): void + { + $dataUrl = 'data:audio/mpeg;base64,SUQzBAAAAAAAfVREUkMAAAAMAAADMg=='; + $audio = Audio::fromDataUrl($dataUrl); + + self::assertSame('SUQzBAAAAAAAfVREUkMAAAAMAAADMg==', $audio->asBase64()); + self::assertSame('audio/mpeg', $audio->getFormat()); + } + + #[Test] + public function fromDataUrlWithInvalidUrl(): void + { + self::expectException(\InvalidArgumentException::class); + self::expectExceptionMessage('Invalid audio data URL format.'); + + Audio::fromDataUrl('invalid-url'); + } + + #[Test] + public function fromFileWithValidPath(): void + { + $audio = Audio::fromFile(\dirname(__DIR__, 3).'/Fixture/audio.mp3'); + + self::assertSame('audio/mpeg', $audio->getFormat()); + self::assertNotEmpty($audio->asBinary()); + } + + #[Test] + public function fromFileWithInvalidPath(): void + { + self::expectException(\InvalidArgumentException::class); + self::expectExceptionMessage('The file "foo.mp3" does not exist or is not readable.'); + + Audio::fromFile('foo.mp3'); + } +} diff --git a/src/platform/tests/Message/Content/BinaryTest.php b/src/platform/tests/Message/Content/BinaryTest.php new file mode 100644 index 000000000..dcc75b4df --- /dev/null +++ b/src/platform/tests/Message/Content/BinaryTest.php @@ -0,0 +1,115 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Message\Content\File; + +#[CoversClass(File::class)] +#[Small] +final class BinaryTest extends TestCase +{ + #[Test] + public function createFromDataUrl(): void + { + $dataUrl = ''; + + $binary = File::fromDataUrl($dataUrl); + + self::assertSame('image/png', $binary->getFormat()); + self::assertNotEmpty($binary->asBinary()); + self::assertSame('iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=', $binary->asBase64()); + } + + #[Test] + public function throwsExceptionForInvalidDataUrl(): void + { + self::expectException(InvalidArgumentException::class); + self::expectExceptionMessage('Invalid audio data URL format.'); + + File::fromDataUrl('invalid-data-url'); + } + + #[Test] + public function createFromFile(): void + { + $content = 'test file content'; + $filename = sys_get_temp_dir().'/binary-test-file.txt'; + file_put_contents($filename, $content); + + try { + $binary = File::fromFile($filename); + + self::assertSame('text/plain', $binary->getFormat()); + self::assertSame($content, $binary->asBinary()); + } finally { + unlink($filename); + } + } + + #[Test] + #[DataProvider('provideExistingFiles')] + public function createFromExistingFiles(string $filePath, string $expectedFormat): void + { + $binary = File::fromFile($filePath); + + self::assertSame($expectedFormat, $binary->getFormat()); + self::assertNotEmpty($binary->asBinary()); + } + + /** + * @return iterable + */ + public static function provideExistingFiles(): iterable + { + yield 'mp3' => [\dirname(__DIR__, 3).'/Fixture/audio.mp3', 'audio/mpeg']; + yield 'jpg' => [\dirname(__DIR__, 3).'/Fixture/image.jpg', 'image/jpeg']; + } + + #[Test] + public function throwsExceptionForNonExistentFile(): void + { + self::expectException(\InvalidArgumentException::class); + + File::fromFile('/non/existent/file.jpg'); + } + + #[Test] + public function convertToDataUrl(): void + { + $data = 'Hello World'; + $format = 'text/plain'; + $binary = new File($data, $format); + + $dataUrl = $binary->asDataUrl(); + + self::assertSame('data:text/plain;base64,'.base64_encode($data), $dataUrl); + } + + #[Test] + public function roundTripConversion(): void + { + $originalDataUrl = 'data:application/pdf;base64,JVBERi0xLjQKJcfsj6IKNSAwIG9iago8PC9MZW5ndGggNiAwIFIvRmls'; + + $binary = File::fromDataUrl($originalDataUrl); + $resultDataUrl = $binary->asDataUrl(); + + self::assertSame($originalDataUrl, $resultDataUrl); + self::assertSame('application/pdf', $binary->getFormat()); + self::assertSame('JVBERi0xLjQKJcfsj6IKNSAwIG9iago8PC9MZW5ndGggNiAwIFIvRmls', $binary->asBase64()); + } +} diff --git a/src/platform/tests/Message/Content/ImageTest.php b/src/platform/tests/Message/Content/ImageTest.php new file mode 100644 index 000000000..82c099885 --- /dev/null +++ b/src/platform/tests/Message/Content/ImageTest.php @@ -0,0 +1,45 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Content\Image; + +#[CoversClass(Image::class)] +final class ImageTest extends TestCase +{ + #[Test] + public function constructWithValidDataUrl(): void + { + $image = Image::fromDataUrl(''); + + self::assertStringStartsWith('data:image/png;base64', $image->asDataUrl()); + } + + #[Test] + public function withValidFile(): void + { + $image = Image::fromFile(\dirname(__DIR__, 3).'/Fixture/image.jpg'); + + self::assertStringStartsWith('data:image/jpeg;base64,', $image->asDataUrl()); + } + + #[Test] + public function fromBinaryWithInvalidFile(): void + { + self::expectExceptionMessage('The file "foo.jpg" does not exist or is not readable.'); + + Image::fromFile('foo.jpg'); + } +} diff --git a/src/platform/tests/Message/Content/ImageUrlTest.php b/src/platform/tests/Message/Content/ImageUrlTest.php new file mode 100644 index 000000000..3455d56c0 --- /dev/null +++ b/src/platform/tests/Message/Content/ImageUrlTest.php @@ -0,0 +1,31 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Content\ImageUrl; + +#[CoversClass(ImageUrl::class)] +#[Small] +final class ImageUrlTest extends TestCase +{ + #[Test] + public function constructWithValidUrl(): void + { + $image = new ImageUrl('https://foo.com/test.png'); + + self::assertSame('https://foo.com/test.png', $image->url); + } +} diff --git a/src/platform/tests/Message/Content/TextTest.php b/src/platform/tests/Message/Content/TextTest.php new file mode 100644 index 000000000..2f791c08a --- /dev/null +++ b/src/platform/tests/Message/Content/TextTest.php @@ -0,0 +1,31 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message\Content; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Content\Text; + +#[CoversClass(Text::class)] +#[Small] +final class TextTest extends TestCase +{ + #[Test] + public function constructionIsPossible(): void + { + $obj = new Text('foo'); + + self::assertSame('foo', $obj->text); + } +} diff --git a/src/platform/tests/Message/MessageBagTest.php b/src/platform/tests/Message/MessageBagTest.php new file mode 100644 index 000000000..17a0c5a63 --- /dev/null +++ b/src/platform/tests/Message/MessageBagTest.php @@ -0,0 +1,178 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Response\ToolCall; + +#[CoversClass(MessageBag::class)] +#[UsesClass(Message::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(ImageUrl::class)] +#[UsesClass(Text::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(ToolCallMessage::class)] +#[Small] +final class MessageBagTest extends TestCase +{ + #[Test] + public function getSystemMessage(): void + { + $messageBag = new MessageBag( + Message::forSystem('My amazing system prompt.'), + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + Message::ofToolCall(new ToolCall('tool', 'tool_name', ['param' => 'value']), 'Yes, go sleeping.'), + ); + + $systemMessage = $messageBag->getSystemMessage(); + + self::assertSame('My amazing system prompt.', $systemMessage->content); + } + + #[Test] + public function getSystemMessageWithoutSystemMessage(): void + { + $messageBag = new MessageBag( + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + Message::ofToolCall(new ToolCall('tool', 'tool_name', ['param' => 'value']), 'Yes, go sleeping.'), + ); + + self::assertNull($messageBag->getSystemMessage()); + } + + #[Test] + public function with(): void + { + $messageBag = new MessageBag( + Message::forSystem('My amazing system prompt.'), + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + ); + + $newMessage = Message::ofAssistant('It is time to wake up.'); + $newMessageBag = $messageBag->with($newMessage); + + self::assertCount(3, $messageBag); + self::assertCount(4, $newMessageBag); + + $newMessageFromBag = $newMessageBag->getMessages()[3]; + + self::assertInstanceOf(AssistantMessage::class, $newMessageFromBag); + self::assertSame('It is time to wake up.', $newMessageFromBag->content); + } + + #[Test] + public function merge(): void + { + $messageBag = new MessageBag( + Message::forSystem('My amazing system prompt.'), + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + ); + + $messageBag = $messageBag->merge(new MessageBag( + Message::ofAssistant('It is time to wake up.') + )); + + self::assertCount(4, $messageBag); + + $messageFromBag = $messageBag->getMessages()[3]; + + self::assertInstanceOf(AssistantMessage::class, $messageFromBag); + self::assertSame('It is time to wake up.', $messageFromBag->content); + } + + #[Test] + public function withoutSystemMessage(): void + { + $messageBag = new MessageBag( + Message::forSystem('My amazing system prompt.'), + Message::ofAssistant('It is time to sleep.'), + Message::forSystem('A system prompt in the middle.'), + Message::ofUser('Hello, world!'), + Message::forSystem('Another system prompt at the end'), + ); + + $newMessageBag = $messageBag->withoutSystemMessage(); + + self::assertCount(5, $messageBag); + self::assertCount(2, $newMessageBag); + + $assistantMessage = $newMessageBag->getMessages()[0]; + self::assertInstanceOf(AssistantMessage::class, $assistantMessage); + self::assertSame('It is time to sleep.', $assistantMessage->content); + + $userMessage = $newMessageBag->getMessages()[1]; + self::assertInstanceOf(UserMessage::class, $userMessage); + self::assertInstanceOf(Text::class, $userMessage->content[0]); + self::assertSame('Hello, world!', $userMessage->content[0]->text); + } + + #[Test] + public function prepend(): void + { + $messageBag = new MessageBag( + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + ); + + $newMessage = Message::forSystem('My amazing system prompt.'); + $newMessageBag = $messageBag->prepend($newMessage); + + self::assertCount(2, $messageBag); + self::assertCount(3, $newMessageBag); + + $newMessageBagMessage = $newMessageBag->getMessages()[0]; + + self::assertInstanceOf(SystemMessage::class, $newMessageBagMessage); + self::assertSame('My amazing system prompt.', $newMessageBagMessage->content); + } + + #[Test] + public function containsImageReturnsFalseWithoutImage(): void + { + $messageBag = new MessageBag( + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + ); + + self::assertFalse($messageBag->containsImage()); + } + + #[Test] + public function containsImageReturnsTrueWithImage(): void + { + $messageBag = new MessageBag( + Message::ofAssistant('It is time to sleep.'), + Message::ofUser('Hello, world!'), + Message::ofUser('My hint for how to analyze an image.', new ImageUrl('http://image-generator.local/my-fancy-image.png')), + ); + + self::assertTrue($messageBag->containsImage()); + } +} diff --git a/src/platform/tests/Message/MessageTest.php b/src/platform/tests/Message/MessageTest.php new file mode 100644 index 000000000..a1e3a7d55 --- /dev/null +++ b/src/platform/tests/Message/MessageTest.php @@ -0,0 +1,111 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\Role; +use Symfony\AI\Platform\Message\SystemMessage; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Message\UserMessage; +use Symfony\AI\Platform\Response\ToolCall; + +#[CoversClass(Message::class)] +#[UsesClass(UserMessage::class)] +#[UsesClass(SystemMessage::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(ToolCallMessage::class)] +#[UsesClass(Role::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(ImageUrl::class)] +#[UsesClass(Text::class)] +#[Small] +final class MessageTest extends TestCase +{ + #[Test] + public function createSystemMessage(): void + { + $message = Message::forSystem('My amazing system prompt.'); + + self::assertSame('My amazing system prompt.', $message->content); + } + + #[Test] + public function createAssistantMessage(): void + { + $message = Message::ofAssistant('It is time to sleep.'); + + self::assertSame('It is time to sleep.', $message->content); + } + + #[Test] + public function createAssistantMessageWithToolCalls(): void + { + $toolCalls = [ + new ToolCall('call_123456', 'my_tool', ['foo' => 'bar']), + new ToolCall('call_456789', 'my_faster_tool'), + ]; + $message = Message::ofAssistant(toolCalls: $toolCalls); + + self::assertCount(2, $message->toolCalls); + self::assertTrue($message->hasToolCalls()); + } + + #[Test] + public function createUserMessage(): void + { + $message = Message::ofUser('Hi, my name is John.'); + + self::assertCount(1, $message->content); + self::assertInstanceOf(Text::class, $message->content[0]); + self::assertSame('Hi, my name is John.', $message->content[0]->text); + } + + #[Test] + public function createUserMessageWithTextContent(): void + { + $text = new Text('Hi, my name is John.'); + $message = Message::ofUser($text); + + self::assertSame([$text], $message->content); + } + + #[Test] + public function createUserMessageWithImages(): void + { + $message = Message::ofUser( + new Text('Hi, my name is John.'), + new ImageUrl('http://images.local/my-image.png'), + 'The following image is a joke.', + new ImageUrl('http://images.local/my-image2.png'), + ); + + self::assertCount(4, $message->content); + } + + #[Test] + public function createToolCallMessage(): void + { + $toolCall = new ToolCall('call_123456', 'my_tool', ['foo' => 'bar']); + $message = Message::ofToolCall($toolCall, 'Foo bar.'); + + self::assertSame('Foo bar.', $message->content); + self::assertSame($toolCall, $message->toolCall); + } +} diff --git a/src/platform/tests/Message/RoleTest.php b/src/platform/tests/Message/RoleTest.php new file mode 100644 index 000000000..bc74d4c68 --- /dev/null +++ b/src/platform/tests/Message/RoleTest.php @@ -0,0 +1,56 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Role; + +#[CoversClass(Role::class)] +#[Small] +final class RoleTest extends TestCase +{ + #[Test] + public function values(): void + { + self::assertSame('system', Role::System->value); + self::assertSame('assistant', Role::Assistant->value); + self::assertSame('user', Role::User->value); + self::assertSame('tool', Role::ToolCall->value); + } + + #[Test] + public function equals(): void + { + self::assertTrue(Role::System->equals(Role::System)); + } + + #[Test] + public function notEquals(): void + { + self::assertTrue(Role::System->notEquals(Role::Assistant)); + } + + #[Test] + public function notEqualsOneOf(): void + { + self::assertTrue(Role::System->notEqualsOneOf([Role::Assistant, Role::User])); + } + + #[Test] + public function equalsOneOf(): void + { + self::assertTrue(Role::System->equalsOneOf([Role::System, Role::User])); + } +} diff --git a/src/platform/tests/Message/SystemMessageTest.php b/src/platform/tests/Message/SystemMessageTest.php new file mode 100644 index 000000000..160f79bc4 --- /dev/null +++ b/src/platform/tests/Message/SystemMessageTest.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Role; +use Symfony\AI\Platform\Message\SystemMessage; + +#[CoversClass(SystemMessage::class)] +#[Small] +final class SystemMessageTest extends TestCase +{ + #[Test] + public function constructionIsPossible(): void + { + $message = new SystemMessage('foo'); + + self::assertSame(Role::System, $message->getRole()); + self::assertSame('foo', $message->content); + } +} diff --git a/src/platform/tests/Message/ToolCallMessageTest.php b/src/platform/tests/Message/ToolCallMessageTest.php new file mode 100644 index 000000000..015cd1070 --- /dev/null +++ b/src/platform/tests/Message/ToolCallMessageTest.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Response\ToolCall; + +#[CoversClass(ToolCallMessage::class)] +#[UsesClass(ToolCall::class)] +#[Small] +final class ToolCallMessageTest extends TestCase +{ + #[Test] + public function constructionIsPossible(): void + { + $toolCall = new ToolCall('foo', 'bar'); + $obj = new ToolCallMessage($toolCall, 'bar'); + + self::assertSame($toolCall, $obj->toolCall); + self::assertSame('bar', $obj->content); + } +} diff --git a/src/platform/tests/Message/UserMessageTest.php b/src/platform/tests/Message/UserMessageTest.php new file mode 100644 index 000000000..4d94d6a01 --- /dev/null +++ b/src/platform/tests/Message/UserMessageTest.php @@ -0,0 +1,83 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Message; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Message\Content\Audio; +use Symfony\AI\Platform\Message\Content\ImageUrl; +use Symfony\AI\Platform\Message\Content\Text; +use Symfony\AI\Platform\Message\Role; +use Symfony\AI\Platform\Message\UserMessage; + +#[CoversClass(UserMessage::class)] +#[UsesClass(Text::class)] +#[UsesClass(Audio::class)] +#[UsesClass(ImageUrl::class)] +#[UsesClass(Role::class)] +#[Small] +final class UserMessageTest extends TestCase +{ + #[Test] + public function constructionIsPossible(): void + { + $obj = new UserMessage(new Text('foo')); + + self::assertSame(Role::User, $obj->getRole()); + self::assertCount(1, $obj->content); + self::assertInstanceOf(Text::class, $obj->content[0]); + self::assertSame('foo', $obj->content[0]->text); + } + + #[Test] + public function constructionIsPossibleWithMultipleContent(): void + { + $message = new UserMessage(new Text('foo'), new ImageUrl('https://foo.com/bar.jpg')); + + self::assertCount(2, $message->content); + } + + #[Test] + public function hasAudioContentWithoutAudio(): void + { + $message = new UserMessage(new Text('foo'), new Text('bar')); + + self::assertFalse($message->hasAudioContent()); + } + + #[Test] + public function hasAudioContentWithAudio(): void + { + $message = new UserMessage(new Text('foo'), Audio::fromFile(\dirname(__DIR__, 2).'/Fixture/audio.mp3')); + + self::assertTrue($message->hasAudioContent()); + } + + #[Test] + public function hasImageContentWithoutImage(): void + { + $message = new UserMessage(new Text('foo'), new Text('bar')); + + self::assertFalse($message->hasImageContent()); + } + + #[Test] + public function hasImageContentWithImage(): void + { + $message = new UserMessage(new Text('foo'), new ImageUrl('https://foo.com/bar.jpg')); + + self::assertTrue($message->hasImageContent()); + } +} diff --git a/src/platform/tests/ModelTest.php b/src/platform/tests/ModelTest.php new file mode 100644 index 000000000..30ae5fcba --- /dev/null +++ b/src/platform/tests/ModelTest.php @@ -0,0 +1,80 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; + +#[CoversClass(Model::class)] +#[Small] +#[UsesClass(Capability::class)] +final class ModelTest extends TestCase +{ + #[Test] + public function returnsName(): void + { + $model = new Model('gpt-4'); + + self::assertSame('gpt-4', $model->getName()); + } + + #[Test] + public function returnsCapabilities(): void + { + $model = new Model('gpt-4', [Capability::INPUT_TEXT, Capability::OUTPUT_TEXT]); + + self::assertSame([Capability::INPUT_TEXT, Capability::OUTPUT_TEXT], $model->getCapabilities()); + } + + #[Test] + public function checksSupportForCapability(): void + { + $model = new Model('gpt-4', [Capability::INPUT_TEXT, Capability::OUTPUT_TEXT]); + + self::assertTrue($model->supports(Capability::INPUT_TEXT)); + self::assertTrue($model->supports(Capability::OUTPUT_TEXT)); + self::assertFalse($model->supports(Capability::INPUT_IMAGE)); + } + + #[Test] + public function returnsEmptyCapabilitiesByDefault(): void + { + $model = new Model('gpt-4'); + + self::assertSame([], $model->getCapabilities()); + } + + #[Test] + public function returnsOptions(): void + { + $options = [ + 'temperature' => 0.7, + 'max_tokens' => 1024, + ]; + $model = new Model('gpt-4', [], $options); + + self::assertSame($options, $model->getOptions()); + } + + #[Test] + public function returnsEmptyOptionsByDefault(): void + { + $model = new Model('gpt-4'); + + self::assertSame([], $model->getOptions()); + } +} diff --git a/src/platform/tests/Response/AsyncResponseTest.php b/src/platform/tests/Response/AsyncResponseTest.php new file mode 100644 index 000000000..2c906ecd4 --- /dev/null +++ b/src/platform/tests/Response/AsyncResponseTest.php @@ -0,0 +1,169 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\AsyncResponse; +use Symfony\AI\Platform\Response\BaseResponse; +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; +use Symfony\AI\Platform\Response\Metadata\Metadata; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +#[CoversClass(AsyncResponse::class)] +#[UsesClass(Metadata::class)] +#[UsesClass(TextResponse::class)] +#[UsesClass(RawResponseAlreadySetException::class)] +#[Small] +final class AsyncResponseTest extends TestCase +{ + #[Test] + public function itUnwrapsTheResponseWhenGettingContent(): void + { + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $textResponse = new TextResponse('test content'); + + $responseConverter = self::createMock(ResponseConverterInterface::class); + $responseConverter->expects(self::once()) + ->method('convert') + ->with($httpResponse, []) + ->willReturn($textResponse); + + $asyncResponse = new AsyncResponse($responseConverter, $httpResponse); + + self::assertSame('test content', $asyncResponse->getContent()); + } + + #[Test] + public function itConvertsTheResponseOnlyOnce(): void + { + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $textResponse = new TextResponse('test content'); + + $responseConverter = self::createMock(ResponseConverterInterface::class); + $responseConverter->expects(self::once()) + ->method('convert') + ->with($httpResponse, []) + ->willReturn($textResponse); + + $asyncResponse = new AsyncResponse($responseConverter, $httpResponse); + + // Call unwrap multiple times, but the converter should only be called once + $asyncResponse->unwrap(); + $asyncResponse->unwrap(); + $asyncResponse->getContent(); + } + + #[Test] + public function itGetsRawResponseDirectly(): void + { + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $responseConverter = $this->createStub(ResponseConverterInterface::class); + + $asyncResponse = new AsyncResponse($responseConverter, $httpResponse); + + self::assertSame($httpResponse, $asyncResponse->getRawResponse()); + } + + #[Test] + public function itThrowsExceptionWhenSettingRawResponse(): void + { + self::expectException(RawResponseAlreadySetException::class); + + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $responseConverter = $this->createStub(ResponseConverterInterface::class); + + $asyncResponse = new AsyncResponse($responseConverter, $httpResponse); + $asyncResponse->setRawResponse($httpResponse); + } + + #[Test] + public function itSetsRawResponseOnUnwrappedResponseWhenNeeded(): void + { + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + + $unwrappedResponse = $this->createResponse(null); + + $responseConverter = $this->createStub(ResponseConverterInterface::class); + $responseConverter->method('convert')->willReturn($unwrappedResponse); + + $asyncResponse = new AsyncResponse($responseConverter, $httpResponse); + $asyncResponse->unwrap(); + + // The raw response in the model response is now set and not null anymore + self::assertSame($httpResponse, $unwrappedResponse->getRawResponse()); + } + + #[Test] + public function itDoesNotSetRawResponseOnUnwrappedResponseWhenAlreadySet(): void + { + $originHttpResponse = $this->createStub(SymfonyHttpResponse::class); + $anotherHttpResponse = $this->createStub(SymfonyHttpResponse::class); + + $unwrappedResponse = $this->createResponse($anotherHttpResponse); + + $responseConverter = $this->createStub(ResponseConverterInterface::class); + $responseConverter->method('convert')->willReturn($unwrappedResponse); + + $asyncResponse = new AsyncResponse($responseConverter, $originHttpResponse); + $asyncResponse->unwrap(); + + // It is still the same raw response as set initially and so not overwritten + self::assertSame($anotherHttpResponse, $unwrappedResponse->getRawResponse()); + } + + /** + * Workaround for low deps because mocking the ResponseInterface leads to an exception with + * mock creation "Type Traversable|object|array|string|null contains both object and a class type" + * in PHPUnit MockClass. + */ + private function createResponse(?SymfonyHttpResponse $rawResponse): ResponseInterface + { + return new class($rawResponse) extends BaseResponse { + public function __construct(protected ?SymfonyHttpResponse $rawResponse) + { + } + + public function getContent(): string + { + return 'test content'; + } + + public function getRawResponse(): ?SymfonyHttpResponse + { + return $this->rawResponse; + } + }; + } + + #[Test] + public function itPassesOptionsToConverter(): void + { + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $options = ['option1' => 'value1', 'option2' => 'value2']; + + $responseConverter = self::createMock(ResponseConverterInterface::class); + $responseConverter->expects(self::once()) + ->method('convert') + ->with($httpResponse, $options) + ->willReturn($this->createResponse(null)); + + $asyncResponse = new AsyncResponse($responseConverter, $httpResponse, $options); + $asyncResponse->unwrap(); + } +} diff --git a/src/platform/tests/Response/BaseResponseTest.php b/src/platform/tests/Response/BaseResponseTest.php new file mode 100644 index 000000000..cd4afdbe6 --- /dev/null +++ b/src/platform/tests/Response/BaseResponseTest.php @@ -0,0 +1,80 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\Attributes\UsesTrait; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\BaseResponse; +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; +use Symfony\AI\Platform\Response\Metadata\Metadata; +use Symfony\AI\Platform\Response\Metadata\MetadataAwareTrait; +use Symfony\AI\Platform\Response\RawResponseAwareTrait; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +#[CoversClass(BaseResponse::class)] +#[UsesTrait(MetadataAwareTrait::class)] +#[UsesTrait(RawResponseAwareTrait::class)] +#[UsesClass(Metadata::class)] +#[UsesClass(RawResponseAlreadySetException::class)] +#[Small] +final class BaseResponseTest extends TestCase +{ + #[Test] + public function itCanHandleMetadata(): void + { + $response = $this->createResponse(); + $metadata = $response->getMetadata(); + + self::assertCount(0, $metadata); + + $metadata->add('key', 'value'); + $metadata = $response->getMetadata(); + + self::assertCount(1, $metadata); + } + + #[Test] + public function itCanBeEnrichedWithARawResponse(): void + { + $response = $this->createResponse(); + $rawResponse = self::createMock(SymfonyHttpResponse::class); + + $response->setRawResponse($rawResponse); + self::assertSame($rawResponse, $response->getRawResponse()); + } + + #[Test] + public function itThrowsAnExceptionWhenSettingARawResponseTwice(): void + { + self::expectException(RawResponseAlreadySetException::class); + + $response = $this->createResponse(); + $rawResponse = self::createMock(SymfonyHttpResponse::class); + + $response->setRawResponse($rawResponse); + $response->setRawResponse($rawResponse); + } + + private function createResponse(): BaseResponse + { + return new class extends BaseResponse { + public function getContent(): string + { + return 'test'; + } + }; + } +} diff --git a/src/platform/tests/Response/ChoiceResponseTest.php b/src/platform/tests/Response/ChoiceResponseTest.php new file mode 100644 index 000000000..f2938cd67 --- /dev/null +++ b/src/platform/tests/Response/ChoiceResponseTest.php @@ -0,0 +1,50 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ChoiceResponse; + +#[CoversClass(ChoiceResponse::class)] +#[UsesClass(Choice::class)] +#[Small] +final class ChoiceResponseTest extends TestCase +{ + #[Test] + public function choiceResponseCreation(): void + { + $choice1 = new Choice('choice1'); + $choice2 = new Choice(null); + $choice3 = new Choice('choice3'); + $response = new ChoiceResponse($choice1, $choice2, $choice3); + + self::assertCount(3, $response->getContent()); + self::assertSame('choice1', $response->getContent()[0]->getContent()); + self::assertNull($response->getContent()[1]->getContent()); + self::assertSame('choice3', $response->getContent()[2]->getContent()); + } + + #[Test] + public function choiceResponseWithNoChoices(): void + { + self::expectException(InvalidArgumentException::class); + self::expectExceptionMessage('Response must have at least one choice.'); + + new ChoiceResponse(); + } +} diff --git a/src/platform/tests/Response/ChoiceTest.php b/src/platform/tests/Response/ChoiceTest.php new file mode 100644 index 000000000..8fcff2324 --- /dev/null +++ b/src/platform/tests/Response/ChoiceTest.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ToolCall; + +#[CoversClass(Choice::class)] +#[UsesClass(ToolCall::class)] +#[Small] +final class ChoiceTest extends TestCase +{ + #[Test] + public function choiceEmpty(): void + { + $choice = new Choice(); + self::assertFalse($choice->hasContent()); + self::assertNull($choice->getContent()); + self::assertFalse($choice->hasToolCall()); + self::assertCount(0, $choice->getToolCalls()); + } + + #[Test] + public function choiceWithContent(): void + { + $choice = new Choice('content'); + self::assertTrue($choice->hasContent()); + self::assertSame('content', $choice->getContent()); + self::assertFalse($choice->hasToolCall()); + self::assertCount(0, $choice->getToolCalls()); + } + + #[Test] + public function choiceWithToolCall(): void + { + $choice = new Choice(null, [new ToolCall('name', 'arguments')]); + self::assertFalse($choice->hasContent()); + self::assertNull($choice->getContent()); + self::assertTrue($choice->hasToolCall()); + self::assertCount(1, $choice->getToolCalls()); + } + + #[Test] + public function choiceWithContentAndToolCall(): void + { + $choice = new Choice('content', [new ToolCall('name', 'arguments')]); + self::assertTrue($choice->hasContent()); + self::assertSame('content', $choice->getContent()); + self::assertTrue($choice->hasToolCall()); + self::assertCount(1, $choice->getToolCalls()); + } +} diff --git a/src/platform/tests/Response/Exception/RawResponseAlreadySetTest.php b/src/platform/tests/Response/Exception/RawResponseAlreadySetTest.php new file mode 100644 index 000000000..1cb275586 --- /dev/null +++ b/src/platform/tests/Response/Exception/RawResponseAlreadySetTest.php @@ -0,0 +1,31 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response\Exception; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; + +#[CoversClass(RawResponseAlreadySetException::class)] +#[Small] +final class RawResponseAlreadySetTest extends TestCase +{ + #[Test] + public function itHasCorrectExceptionMessage(): void + { + $exception = new RawResponseAlreadySetException(); + + self::assertSame('The raw response was already set.', $exception->getMessage()); + } +} diff --git a/src/platform/tests/Response/Metadata/MetadataAwareTraitTest.php b/src/platform/tests/Response/Metadata/MetadataAwareTraitTest.php new file mode 100644 index 000000000..5edb33372 --- /dev/null +++ b/src/platform/tests/Response/Metadata/MetadataAwareTraitTest.php @@ -0,0 +1,47 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response\Metadata; + +use PHPUnit\Framework\Attributes\CoversTrait; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\Metadata\Metadata; +use Symfony\AI\Platform\Response\Metadata\MetadataAwareTrait; + +#[CoversTrait(MetadataAwareTrait::class)] +#[Small] +#[UsesClass(Metadata::class)] +final class MetadataAwareTraitTest extends TestCase +{ + #[Test] + public function itCanHandleMetadata(): void + { + $response = $this->createTestClass(); + $metadata = $response->getMetadata(); + + self::assertCount(0, $metadata); + + $metadata->add('key', 'value'); + $metadata = $response->getMetadata(); + + self::assertCount(1, $metadata); + } + + private function createTestClass(): object + { + return new class { + use MetadataAwareTrait; + }; + } +} diff --git a/src/platform/tests/Response/Metadata/MetadataTest.php b/src/platform/tests/Response/Metadata/MetadataTest.php new file mode 100644 index 000000000..f55e62b8a --- /dev/null +++ b/src/platform/tests/Response/Metadata/MetadataTest.php @@ -0,0 +1,137 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response\Metadata; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\Metadata\Metadata; + +#[CoversClass(Metadata::class)] +#[Small] +final class MetadataTest extends TestCase +{ + #[Test] + public function itCanBeCreatedEmpty(): void + { + $metadata = new Metadata(); + self::assertCount(0, $metadata); + self::assertSame([], $metadata->all()); + } + + #[Test] + public function itCanBeCreatedWithInitialData(): void + { + $metadata = new Metadata(['key' => 'value']); + self::assertCount(1, $metadata); + self::assertSame(['key' => 'value'], $metadata->all()); + } + + #[Test] + public function itCanAddNewMetadata(): void + { + $metadata = new Metadata(); + $metadata->add('key', 'value'); + + self::assertTrue($metadata->has('key')); + self::assertSame('value', $metadata->get('key')); + } + + #[Test] + public function itCanCheckIfMetadataExists(): void + { + $metadata = new Metadata(['key' => 'value']); + + self::assertTrue($metadata->has('key')); + self::assertFalse($metadata->has('nonexistent')); + } + + #[Test] + public function itCanGetMetadataWithDefault(): void + { + $metadata = new Metadata(['key' => 'value']); + + self::assertSame('value', $metadata->get('key')); + self::assertSame('default', $metadata->get('nonexistent', 'default')); + self::assertNull($metadata->get('nonexistent')); + } + + #[Test] + public function itCanRemoveMetadata(): void + { + $metadata = new Metadata(['key' => 'value']); + self::assertTrue($metadata->has('key')); + + $metadata->remove('key'); + self::assertFalse($metadata->has('key')); + } + + #[Test] + public function itCanSetEntireMetadataArray(): void + { + $metadata = new Metadata(['key1' => 'value1']); + $metadata->set(['key2' => 'value2', 'key3' => 'value3']); + + self::assertFalse($metadata->has('key1')); + self::assertTrue($metadata->has('key2')); + self::assertTrue($metadata->has('key3')); + self::assertSame(['key2' => 'value2', 'key3' => 'value3'], $metadata->all()); + } + + #[Test] + public function itImplementsJsonSerializable(): void + { + $metadata = new Metadata(['key' => 'value']); + self::assertSame(['key' => 'value'], $metadata->jsonSerialize()); + } + + #[Test] + public function itImplementsArrayAccess(): void + { + $metadata = new Metadata(['key' => 'value']); + + self::assertArrayHasKey('key', $metadata); + self::assertSame('value', $metadata['key']); + + $metadata['new'] = 'newValue'; + self::assertSame('newValue', $metadata['new']); + + unset($metadata['key']); + self::assertArrayNotHasKey('key', $metadata); + } + + #[Test] + public function itImplementsIteratorAggregate(): void + { + $metadata = new Metadata(['key1' => 'value1', 'key2' => 'value2']); + $result = iterator_to_array($metadata); + + self::assertSame(['key1' => 'value1', 'key2' => 'value2'], $result); + } + + #[Test] + public function itImplementsCountable(): void + { + $metadata = new Metadata(); + self::assertCount(0, $metadata); + + $metadata->add('key', 'value'); + self::assertCount(1, $metadata); + + $metadata->add('key2', 'value2'); + self::assertCount(2, $metadata); + + $metadata->remove('key'); + self::assertCount(1, $metadata); + } +} diff --git a/src/platform/tests/Response/RawResponseAwareTraitTest.php b/src/platform/tests/Response/RawResponseAwareTraitTest.php new file mode 100644 index 000000000..03ad37df4 --- /dev/null +++ b/src/platform/tests/Response/RawResponseAwareTraitTest.php @@ -0,0 +1,56 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversTrait; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\Exception\RawResponseAlreadySetException; +use Symfony\AI\Platform\Response\RawResponseAwareTrait; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +#[CoversTrait(RawResponseAwareTrait::class)] +#[Small] +#[UsesClass(RawResponseAlreadySetException::class)] +final class RawResponseAwareTraitTest extends TestCase +{ + #[Test] + public function itCanBeEnrichedWithARawResponse(): void + { + $response = $this->createTestClass(); + $rawResponse = self::createMock(SymfonyHttpResponse::class); + + $response->setRawResponse($rawResponse); + self::assertSame($rawResponse, $response->getRawResponse()); + } + + #[Test] + public function itThrowsAnExceptionWhenSettingARawResponseTwice(): void + { + self::expectException(RawResponseAlreadySetException::class); + + $response = $this->createTestClass(); + $rawResponse = self::createMock(SymfonyHttpResponse::class); + + $response->setRawResponse($rawResponse); + $response->setRawResponse($rawResponse); + } + + private function createTestClass(): object + { + return new class { + use RawResponseAwareTrait; + }; + } +} diff --git a/src/platform/tests/Response/StreamResponseTest.php b/src/platform/tests/Response/StreamResponseTest.php new file mode 100644 index 000000000..3242f9749 --- /dev/null +++ b/src/platform/tests/Response/StreamResponseTest.php @@ -0,0 +1,41 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\StreamResponse; + +#[CoversClass(StreamResponse::class)] +#[Small] +final class StreamResponseTest extends TestCase +{ + #[Test] + public function getContent(): void + { + $generator = (function () { + yield 'data1'; + yield 'data2'; + })(); + + $response = new StreamResponse($generator); + self::assertInstanceOf(\Generator::class, $response->getContent()); + + $content = iterator_to_array($response->getContent()); + + self::assertCount(2, $content); + self::assertSame('data1', $content[0]); + self::assertSame('data2', $content[1]); + } +} diff --git a/src/platform/tests/Response/StructuredResponseTest.php b/src/platform/tests/Response/StructuredResponseTest.php new file mode 100644 index 000000000..69c1232f8 --- /dev/null +++ b/src/platform/tests/Response/StructuredResponseTest.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\ObjectResponse; + +#[CoversClass(ObjectResponse::class)] +#[Small] +final class StructuredResponseTest extends TestCase +{ + #[Test] + public function getContentWithArray(): void + { + $response = new ObjectResponse($expected = ['foo' => 'bar', 'baz' => ['qux']]); + self::assertSame($expected, $response->getContent()); + } + + #[Test] + public function getContentWithObject(): void + { + $response = new ObjectResponse($expected = (object) ['foo' => 'bar', 'baz' => ['qux']]); + self::assertSame($expected, $response->getContent()); + } +} diff --git a/src/platform/tests/Response/TextResponseTest.php b/src/platform/tests/Response/TextResponseTest.php new file mode 100644 index 000000000..9a2fd13c2 --- /dev/null +++ b/src/platform/tests/Response/TextResponseTest.php @@ -0,0 +1,30 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\TextResponse; + +#[CoversClass(TextResponse::class)] +#[Small] +final class TextResponseTest extends TestCase +{ + #[Test] + public function getContent(): void + { + $response = new TextResponse($expected = 'foo'); + self::assertSame($expected, $response->getContent()); + } +} diff --git a/src/platform/tests/Response/TollCallResponseTest.php b/src/platform/tests/Response/TollCallResponseTest.php new file mode 100644 index 000000000..8978f17ec --- /dev/null +++ b/src/platform/tests/Response/TollCallResponseTest.php @@ -0,0 +1,43 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; + +#[CoversClass(ToolCallResponse::class)] +#[UsesClass(ToolCall::class)] +#[Small] +final class TollCallResponseTest extends TestCase +{ + #[Test] + public function throwsIfNoToolCall(): void + { + self::expectException(InvalidArgumentException::class); + self::expectExceptionMessage('Response must have at least one tool call.'); + + new ToolCallResponse(); + } + + #[Test] + public function getContent(): void + { + $response = new ToolCallResponse($toolCall = new ToolCall('ID', 'name', ['foo' => 'bar'])); + self::assertSame([$toolCall], $response->getContent()); + } +} diff --git a/src/platform/tests/Response/ToolCallTest.php b/src/platform/tests/Response/ToolCallTest.php new file mode 100644 index 000000000..5737c6b85 --- /dev/null +++ b/src/platform/tests/Response/ToolCallTest.php @@ -0,0 +1,46 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Response; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Response\ToolCall; + +#[CoversClass(ToolCall::class)] +#[Small] +final class ToolCallTest extends TestCase +{ + #[Test] + public function toolCall(): void + { + $toolCall = new ToolCall('id', 'name', ['foo' => 'bar']); + self::assertSame('id', $toolCall->id); + self::assertSame('name', $toolCall->name); + self::assertSame(['foo' => 'bar'], $toolCall->arguments); + } + + #[Test] + public function toolCallJsonSerialize(): void + { + $toolCall = new ToolCall('id', 'name', ['foo' => 'bar']); + self::assertSame([ + 'id' => 'id', + 'type' => 'function', + 'function' => [ + 'name' => 'name', + 'arguments' => '{"foo":"bar"}', + ], + ], $toolCall->jsonSerialize()); + } +} diff --git a/src/store/.gitattributes b/src/store/.gitattributes new file mode 100644 index 000000000..ec8c01802 --- /dev/null +++ b/src/store/.gitattributes @@ -0,0 +1,6 @@ +/.github export-ignore +/tests export-ignore +.gitattributes export-ignore +.gitignore export-ignore +phpstan.dist.neon export-ignore +phpunit.xml.dist export-ignore diff --git a/src/store/.github/PULL_REQUEST_TEMPLATE.md b/src/store/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..fcb87228a --- /dev/null +++ b/src/store/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ +Please do not submit any Pull Requests here. They will be closed. +--- + +Please submit your PR here instead: +https://github.com/symfony/ai + +This repository is what we call a "subtree split": a read-only subset of that main repository. +We're looking forward to your PR there! diff --git a/src/store/.github/workflows/close-pull-request.yml b/src/store/.github/workflows/close-pull-request.yml new file mode 100644 index 000000000..207153fd5 --- /dev/null +++ b/src/store/.github/workflows/close-pull-request.yml @@ -0,0 +1,20 @@ +name: Close Pull Request + +on: + pull_request_target: + types: [opened] + +jobs: + run: + runs-on: ubuntu-latest + steps: + - uses: superbrothers/close-pull-request@v3 + with: + comment: | + Thanks for your Pull Request! We love contributions. + + However, you should instead open your PR on the main repository: + https://github.com/symfony/ai + + This repository is what we call a "subtree split": a read-only subset of that main repository. + We're looking forward to your PR there! diff --git a/src/store/.gitignore b/src/store/.gitignore new file mode 100644 index 000000000..f43db636b --- /dev/null +++ b/src/store/.gitignore @@ -0,0 +1,3 @@ +composer.lock +vendor +.phpunit.cache diff --git a/src/store/LICENSE b/src/store/LICENSE new file mode 100644 index 000000000..bc38d714e --- /dev/null +++ b/src/store/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2025-present Fabien Potencier + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/src/store/composer.json b/src/store/composer.json new file mode 100644 index 000000000..60b90fdbd --- /dev/null +++ b/src/store/composer.json @@ -0,0 +1,75 @@ +{ + "name": "symfony/ai-store", + "type": "library", + "description": "PHP library for abstracting interaction with data stores in AI applications.", + "keywords": [ + "ai", + "mongodb", + "pinecone", + "chromadb" + ], + "license": "MIT", + "authors": [ + { + "name": "Christopher Hertel", + "email": "mail@christopher-hertel.de" + }, + { + "name": "Oskar Stark", + "email": "oskarstark@googlemail.com" + } + ], + "require": { + "php": ">=8.2", + "ext-fileinfo": "*", + "oskarstark/enum-helper": "^1.5", + "phpdocumentor/reflection-docblock": "^5.4", + "phpstan/phpdoc-parser": "^2.1", + "psr/cache": "^3.0", + "psr/log": "^3.0", + "symfony/clock": "^6.4 || ^7.1", + "symfony/http-client": "^6.4 || ^7.1", + "symfony/property-access": "^6.4 || ^7.1", + "symfony/property-info": "^6.4 || ^7.1", + "symfony/serializer": "^6.4 || ^7.1", + "symfony/type-info": "^7.2.3", + "symfony/uid": "^6.4 || ^7.1", + "webmozart/assert": "^1.11" + }, + "conflict": { + "mongodb/mongodb": "<1.21" + }, + "require-dev": { + "codewithkyrian/chromadb-php": "^0.2.1 || ^0.3 || ^0.4", + "mongodb/mongodb": "^1.21", + "phpstan/phpstan": "^2.0", + "phpstan/phpstan-symfony": "^2.0", + "phpstan/phpstan-webmozart-assert": "^2.0", + "phpunit/phpunit": "^11.5", + "probots-io/pinecone-php": "^1.0", + "symfony/console": "^6.4 || ^7.1", + "symfony/dotenv": "^6.4 || ^7.1", + "symfony/event-dispatcher": "^6.4 || ^7.1", + "symfony/finder": "^6.4 || ^7.1", + "symfony/process": "^6.4 || ^7.1", + "symfony/var-dumper": "^6.4 || ^7.1" + }, + "suggest": { + "codewithkyrian/chromadb-php": "For using the ChromaDB as retrieval vector store.", + "mongodb/mongodb": "For using MongoDB Atlas as retrieval vector store.", + "probots-io/pinecone-php": "For using the Pinecone as retrieval vector store." + }, + "config": { + "sort-packages": true + }, + "autoload": { + "psr-4": { + "Symfony\\AI\\Store\\": "src/" + } + }, + "autoload-dev": { + "psr-4": { + "Symfony\\AI\\Store\\Tests\\": "tests/" + } + } +} diff --git a/src/store/phpstan.dist.neon b/src/store/phpstan.dist.neon new file mode 100644 index 000000000..8cc83f644 --- /dev/null +++ b/src/store/phpstan.dist.neon @@ -0,0 +1,10 @@ +includes: + - vendor/phpstan/phpstan-webmozart-assert/extension.neon + - vendor/phpstan/phpstan-symfony/extension.neon + +parameters: + level: 6 + paths: + - src/ + - tests/ + diff --git a/src/store/phpunit.xml.dist b/src/store/phpunit.xml.dist new file mode 100644 index 000000000..4e9e3a684 --- /dev/null +++ b/src/store/phpunit.xml.dist @@ -0,0 +1,24 @@ + + + + + tests + + + + + + src + + + diff --git a/src/store/src/Bridge/Azure/SearchStore.php b/src/store/src/Bridge/Azure/SearchStore.php new file mode 100644 index 000000000..a2b945e5e --- /dev/null +++ b/src/store/src/Bridge/Azure/SearchStore.php @@ -0,0 +1,121 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Bridge\Azure; + +use Symfony\AI\Platform\Vector\NullVector; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Uid\Uuid; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class SearchStore implements VectorStoreInterface +{ + /** + * @param string $vectorFieldName The name of the field int the index that contains the vector + */ + public function __construct( + private HttpClientInterface $httpClient, + private string $endpointUrl, + #[\SensitiveParameter] private string $apiKey, + private string $indexName, + private string $apiVersion, + private string $vectorFieldName = 'vector', + ) { + } + + public function add(VectorDocument ...$documents): void + { + $this->request('index', [ + 'value' => array_map([$this, 'convertToIndexableArray'], $documents), + ]); + } + + public function query(Vector $vector, array $options = [], ?float $minScore = null): array + { + $result = $this->request('search', [ + 'vectorQueries' => [$this->buildVectorQuery($vector)], + ]); + + return array_map([$this, 'convertToVectorDocument'], $result['value']); + } + + /** + * @param array $payload + * + * @return array + */ + private function request(string $endpoint, array $payload): array + { + $url = \sprintf('%s/indexes/%s/docs/%s', $this->endpointUrl, $this->indexName, $endpoint); + $response = $this->httpClient->request('POST', $url, [ + 'headers' => [ + 'api-key' => $this->apiKey, + ], + 'query' => ['api-version' => $this->apiVersion], + 'json' => $payload, + ]); + + return $response->toArray(); + } + + /** + * @return array + */ + private function convertToIndexableArray(VectorDocument $document): array + { + return array_merge([ + 'id' => $document->id, + $this->vectorFieldName => $document->vector->getData(), + ], $document->metadata->getArrayCopy()); + } + + /** + * @param array $data + */ + private function convertToVectorDocument(array $data): VectorDocument + { + return new VectorDocument( + id: Uuid::fromString($data['id']), + vector: !\array_key_exists($this->vectorFieldName, $data) || null === $data[$this->vectorFieldName] + ? new NullVector() + : new Vector($data[$this->vectorFieldName]), + metadata: new Metadata($data), + ); + } + + /** + * @return array{ + * kind: 'vector', + * vector: float[], + * exhaustive: true, + * fields: non-empty-string, + * weight: float, + * k: int, + * } + */ + private function buildVectorQuery(Vector $vector): array + { + return [ + 'kind' => 'vector', + 'vector' => $vector->getData(), + 'exhaustive' => true, + 'fields' => $this->vectorFieldName, + 'weight' => 0.5, + 'k' => 5, + ]; + } +} diff --git a/src/store/src/Bridge/ChromaDB/Store.php b/src/store/src/Bridge/ChromaDB/Store.php new file mode 100644 index 000000000..27dc1053f --- /dev/null +++ b/src/store/src/Bridge/ChromaDB/Store.php @@ -0,0 +1,66 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Bridge\ChromaDB; + +use Codewithkyrian\ChromaDB\Client; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Uid\Uuid; + +/** + * @author Christopher Hertel + */ +final readonly class Store implements VectorStoreInterface +{ + public function __construct( + private Client $client, + private string $collectionName, + ) { + } + + public function add(VectorDocument ...$documents): void + { + $ids = []; + $vectors = []; + $metadata = []; + foreach ($documents as $document) { + $ids[] = (string) $document->id; + $vectors[] = $document->vector->getData(); + $metadata[] = $document->metadata->getArrayCopy(); + } + + $collection = $this->client->getOrCreateCollection($this->collectionName); + $collection->add($ids, $vectors, $metadata); + } + + public function query(Vector $vector, array $options = [], ?float $minScore = null): array + { + $collection = $this->client->getOrCreateCollection($this->collectionName); + $queryResponse = $collection->query( + queryEmbeddings: [$vector->getData()], + nResults: 4, + ); + + $documents = []; + for ($i = 0; $i < \count($queryResponse->metadatas[0]); ++$i) { + $documents[] = new VectorDocument( + id: Uuid::fromString($queryResponse->ids[0][$i]), + vector: new Vector($queryResponse->embeddings[0][$i]), + metadata: new Metadata($queryResponse->metadatas[0][$i]), + ); + } + + return $documents; + } +} diff --git a/src/store/src/Bridge/MongoDB/Store.php b/src/store/src/Bridge/MongoDB/Store.php new file mode 100644 index 000000000..bf6b90098 --- /dev/null +++ b/src/store/src/Bridge/MongoDB/Store.php @@ -0,0 +1,196 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Bridge\MongoDB; + +use MongoDB\BSON\Binary; +use MongoDB\Client; +use MongoDB\Collection; +use MongoDB\Driver\Exception\CommandException; +use Psr\Log\LoggerInterface; +use Psr\Log\NullLogger; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\InvalidArgumentException; +use Symfony\AI\Store\InitializableStoreInterface; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Uid\Uuid; + +/** + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/ + * + * For this store you need to create a separate MongoDB Atlas Search index. + * The index needs to be created with the following settings: + * { + * "fields": [ + * { + * "numDimensions": 1536, + * "path": "vector", + * "similarity": "euclidean", + * "type": "vector" + * } + * ] + * } + * + * Note, that the `path` key needs to match the $vectorFieldName. + * + * For the `similarity` key you can choose between `euclidean`, `cosine` and `dotProduct`. + * {@see https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/#define-the-index-for-the-fts-field-type-type} + * + * @author Oskar Stark + */ +final readonly class Store implements VectorStoreInterface, InitializableStoreInterface +{ + /** + * @param string $databaseName The name of the database + * @param string $collectionName The name of the collection + * @param string $indexName The name of the Atlas Search index + * @param string $vectorFieldName The name of the field int the index that contains the vector + * @param bool $bulkWrite Use bulk write operations + */ + public function __construct( + private Client $client, + private string $databaseName, + private string $collectionName, + private string $indexName, + private string $vectorFieldName = 'vector', + private bool $bulkWrite = false, + private LoggerInterface $logger = new NullLogger(), + ) { + } + + public function add(VectorDocument ...$documents): void + { + $operations = []; + + foreach ($documents as $document) { + $operation = [ + ['_id' => $this->toBinary($document->id)], // we use binary for the id, because of storage efficiency + array_filter([ + 'metadata' => $document->metadata->getArrayCopy(), + $this->vectorFieldName => $document->vector->getData(), + ]), + ['upsert' => true], // insert if not exists + ]; + + if ($this->bulkWrite) { + $operations[] = ['replaceOne' => $operation]; + continue; + } + + $this->getCollection()->replaceOne(...$operation); + } + + if ($this->bulkWrite) { + $this->getCollection()->bulkWrite($operations); + } + } + + /** + * @param array{ + * limit?: positive-int, + * numCandidates?: positive-int, + * filter?: array + * } $options + */ + public function query(Vector $vector, array $options = [], ?float $minScore = null): array + { + $pipeline = [ + [ + '$vectorSearch' => array_merge([ + 'index' => $this->indexName, + 'path' => $this->vectorFieldName, + 'queryVector' => $vector->getData(), + 'numCandidates' => 200, + 'limit' => 5, + ], $options), + ], + [ + '$addFields' => [ + 'score' => ['$meta' => 'vectorSearchScore'], + ], + ], + ]; + + if (null !== $minScore) { + $pipeline[] = [ + '$match' => [ + 'score' => ['$gte' => $minScore], + ], + ]; + } + + $results = $this->getCollection()->aggregate( + $pipeline, + ['typeMap' => ['root' => 'array', 'document' => 'array', 'array' => 'array']] + ); + + $documents = []; + + foreach ($results as $result) { + $documents[] = new VectorDocument( + id: $this->toUuid($result['_id']), + vector: new Vector($result[$this->vectorFieldName]), + metadata: new Metadata($result['metadata'] ?? []), + score: $result['score'], + ); + } + + return $documents; + } + + /** + * @param array{fields?: array} $options + */ + public function initialize(array $options = []): void + { + if ([] !== $options && !\array_key_exists('fields', $options)) { + throw new InvalidArgumentException('The only supported option is "fields"'); + } + + try { + $this->getCollection()->createSearchIndex( + [ + 'fields' => array_merge([ + [ + 'numDimensions' => 1536, + 'path' => $this->vectorFieldName, + 'similarity' => 'euclidean', + 'type' => 'vector', + ], + ], $options['fields'] ?? []), + ], + [ + 'name' => $this->indexName, + 'type' => 'vectorSearch', + ], + ); + } catch (CommandException $e) { + $this->logger->warning($e->getMessage()); + } + } + + private function getCollection(): Collection + { + return $this->client->getCollection($this->databaseName, $this->collectionName); + } + + private function toBinary(Uuid $uuid): Binary + { + return new Binary($uuid->toBinary(), Binary::TYPE_UUID); + } + + private function toUuid(Binary $binary): Uuid + { + return Uuid::fromString($binary->getData()); + } +} diff --git a/src/store/src/Bridge/Pinecone/Store.php b/src/store/src/Bridge/Pinecone/Store.php new file mode 100644 index 000000000..f17923148 --- /dev/null +++ b/src/store/src/Bridge/Pinecone/Store.php @@ -0,0 +1,83 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Bridge\Pinecone; + +use Probots\Pinecone\Client; +use Probots\Pinecone\Resources\Data\VectorResource; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Uid\Uuid; + +/** + * @author Christopher Hertel + */ +final readonly class Store implements VectorStoreInterface +{ + /** + * @param array $filter + */ + public function __construct( + private Client $pinecone, + private ?string $namespace = null, + private array $filter = [], + private int $topK = 3, + ) { + } + + public function add(VectorDocument ...$documents): void + { + $vectors = []; + foreach ($documents as $document) { + $vectors[] = [ + 'id' => (string) $document->id, + 'values' => $document->vector->getData(), + 'metadata' => $document->metadata->getArrayCopy(), + ]; + } + + if ([] === $vectors) { + return; + } + + $this->getVectors()->upsert($vectors, $this->namespace); + } + + public function query(Vector $vector, array $options = [], ?float $minScore = null): array + { + $response = $this->getVectors()->query( + vector: $vector->getData(), + namespace: $options['namespace'] ?? $this->namespace, + filter: $options['filter'] ?? $this->filter, + topK: $options['topK'] ?? $this->topK, + includeValues: true, + ); + + $documents = []; + foreach ($response->json()['matches'] as $match) { + $documents[] = new VectorDocument( + id: Uuid::fromString($match['id']), + vector: new Vector($match['values']), + metadata: new Metadata($match['metadata']), + score: $match['score'], + ); + } + + return $documents; + } + + private function getVectors(): VectorResource + { + return $this->pinecone->data()->vectors(); + } +} diff --git a/src/store/src/Document/Metadata.php b/src/store/src/Document/Metadata.php new file mode 100644 index 000000000..5ce7c105c --- /dev/null +++ b/src/store/src/Document/Metadata.php @@ -0,0 +1,21 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Document; + +/** + * @template-extends \ArrayObject + * + * @author Christopher Hertel + */ +final class Metadata extends \ArrayObject +{ +} diff --git a/src/store/src/Document/TextDocument.php b/src/store/src/Document/TextDocument.php new file mode 100644 index 000000000..bfb4f7f20 --- /dev/null +++ b/src/store/src/Document/TextDocument.php @@ -0,0 +1,29 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Document; + +use Symfony\Component\Uid\Uuid; +use Webmozart\Assert\Assert; + +/** + * @author Christopher Hertel + */ +final readonly class TextDocument +{ + public function __construct( + public Uuid $id, + public string $content, + public Metadata $metadata = new Metadata(), + ) { + Assert::stringNotEmpty(trim($this->content)); + } +} diff --git a/src/store/src/Document/VectorDocument.php b/src/store/src/Document/VectorDocument.php new file mode 100644 index 000000000..294c94521 --- /dev/null +++ b/src/store/src/Document/VectorDocument.php @@ -0,0 +1,29 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Document; + +use Symfony\AI\Platform\Vector\VectorInterface; +use Symfony\Component\Uid\Uuid; + +/** + * @author Christopher Hertel + */ +final readonly class VectorDocument +{ + public function __construct( + public Uuid $id, + public VectorInterface $vector, + public Metadata $metadata = new Metadata(), + public ?float $score = null, + ) { + } +} diff --git a/src/store/src/Embedder.php b/src/store/src/Embedder.php new file mode 100644 index 000000000..d19b3243a --- /dev/null +++ b/src/store/src/Embedder.php @@ -0,0 +1,97 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +use Psr\Log\LoggerInterface; +use Psr\Log\NullLogger; +use Symfony\AI\Platform\Capability; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Store\Document\TextDocument; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\Component\Clock\Clock; +use Symfony\Component\Clock\ClockInterface; + +/** + * @author Christopher Hertel + */ +final readonly class Embedder +{ + private ClockInterface $clock; + + public function __construct( + private PlatformInterface $platform, + private Model $model, + private StoreInterface $store, + ?ClockInterface $clock = null, + private LoggerInterface $logger = new NullLogger(), + ) { + $this->clock = $clock ?? Clock::get(); + } + + /** + * @param TextDocument|TextDocument[] $documents + */ + public function embed(TextDocument|array $documents, int $chunkSize = 0, int $sleep = 0): void + { + if ($documents instanceof TextDocument) { + $documents = [$documents]; + } + + if ([] === $documents) { + $this->logger->debug('No documents to embed'); + + return; + } + + $chunks = 0 !== $chunkSize ? array_chunk($documents, $chunkSize) : [$documents]; + + foreach ($chunks as $chunk) { + $this->store->add(...$this->createVectorDocuments($chunk)); + + if (0 !== $sleep) { + $this->clock->sleep($sleep); + } + } + } + + /** + * @param TextDocument[] $documents + * + * @return VectorDocument[] + */ + private function createVectorDocuments(array $documents): array + { + if ($this->model->supports(Capability::INPUT_MULTIPLE)) { + $response = $this->platform->request($this->model, array_map(fn (TextDocument $document) => $document->content, $documents)); + + $vectors = $response->getContent(); + } else { + $responses = []; + foreach ($documents as $document) { + $responses[] = $this->platform->request($this->model, $document->content); + } + + $vectors = []; + foreach ($responses as $response) { + $vectors = array_merge($vectors, $response->getContent()); + } + } + + $vectorDocuments = []; + foreach ($documents as $i => $document) { + $vectorDocuments[] = new VectorDocument($document->id, $vectors[$i], $document->metadata); + } + + return $vectorDocuments; + } +} diff --git a/src/store/src/Exception/ExceptionInterface.php b/src/store/src/Exception/ExceptionInterface.php new file mode 100644 index 000000000..918a9005e --- /dev/null +++ b/src/store/src/Exception/ExceptionInterface.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Exception; + +/** + * @author Oskar Stark + */ +interface ExceptionInterface extends \Throwable +{ +} diff --git a/src/store/src/Exception/InvalidArgumentException.php b/src/store/src/Exception/InvalidArgumentException.php new file mode 100644 index 000000000..82cbefdd4 --- /dev/null +++ b/src/store/src/Exception/InvalidArgumentException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Exception; + +/** + * @author Oskar Stark + */ +class InvalidArgumentException extends \InvalidArgumentException implements ExceptionInterface +{ +} diff --git a/src/store/src/Exception/RuntimeException.php b/src/store/src/Exception/RuntimeException.php new file mode 100644 index 000000000..6cd47c742 --- /dev/null +++ b/src/store/src/Exception/RuntimeException.php @@ -0,0 +1,19 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Exception; + +/** + * @author Oskar Stark + */ +class RuntimeException extends \RuntimeException implements ExceptionInterface +{ +} diff --git a/src/store/src/InitializableStoreInterface.php b/src/store/src/InitializableStoreInterface.php new file mode 100644 index 000000000..f5c62837a --- /dev/null +++ b/src/store/src/InitializableStoreInterface.php @@ -0,0 +1,23 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +/** + * @author Oskar Stark + */ +interface InitializableStoreInterface extends StoreInterface +{ + /** + * @param array $options + */ + public function initialize(array $options = []): void; +} diff --git a/src/store/src/StoreInterface.php b/src/store/src/StoreInterface.php new file mode 100644 index 000000000..65e585af7 --- /dev/null +++ b/src/store/src/StoreInterface.php @@ -0,0 +1,22 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +use Symfony\AI\Store\Document\VectorDocument; + +/** + * @author Christopher Hertel + */ +interface StoreInterface +{ + public function add(VectorDocument ...$documents): void; +} diff --git a/src/store/src/VectorStoreInterface.php b/src/store/src/VectorStoreInterface.php new file mode 100644 index 000000000..2df2ade97 --- /dev/null +++ b/src/store/src/VectorStoreInterface.php @@ -0,0 +1,28 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\VectorDocument; + +/** + * @author Christopher Hertel + */ +interface VectorStoreInterface extends StoreInterface +{ + /** + * @param array $options + * + * @return VectorDocument[] + */ + public function query(Vector $vector, array $options = [], ?float $minScore = null): array; +} diff --git a/src/store/tests/Document/NullVectorTest.php b/src/store/tests/Document/NullVectorTest.php new file mode 100644 index 000000000..8423f31f0 --- /dev/null +++ b/src/store/tests/Document/NullVectorTest.php @@ -0,0 +1,45 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests\Document; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Vector\NullVector; +use Symfony\AI\Platform\Vector\VectorInterface; +use Symfony\AI\Store\Exception\RuntimeException; + +#[CoversClass(NullVector::class)] +final class NullVectorTest extends TestCase +{ + #[Test] + public function implementsInterface(): void + { + self::assertInstanceOf(VectorInterface::class, new NullVector()); + } + + #[Test] + public function getDataThrowsOnAccess(): void + { + self::expectException(RuntimeException::class); + + (new NullVector())->getData(); + } + + #[Test] + public function getDimensionsThrowsOnAccess(): void + { + self::expectException(RuntimeException::class); + + (new NullVector())->getDimensions(); + } +} diff --git a/src/store/tests/Document/VectorTest.php b/src/store/tests/Document/VectorTest.php new file mode 100644 index 000000000..d14b57970 --- /dev/null +++ b/src/store/tests/Document/VectorTest.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests\Document; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Platform\Vector\VectorInterface; + +#[CoversClass(Vector::class)] +final class VectorTest extends TestCase +{ + #[Test] + public function implementsInterface(): void + { + self::assertInstanceOf( + VectorInterface::class, + new Vector([1.0, 2.0, 3.0]) + ); + } + + #[Test] + public function withDimensionNull(): void + { + $vector = new Vector($vectors = [1.0, 2.0, 3.0], null); + + self::assertSame($vectors, $vector->getData()); + self::assertSame(3, $vector->getDimensions()); + } +} diff --git a/src/store/tests/Double/PlatformTestHandler.php b/src/store/tests/Double/PlatformTestHandler.php new file mode 100644 index 000000000..7c616421a --- /dev/null +++ b/src/store/tests/Double/PlatformTestHandler.php @@ -0,0 +1,57 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests\Double; + +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Platform; +use Symfony\AI\Platform\Response\ResponseInterface; +use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\ResponseConverterInterface; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Component\HttpClient\Response\MockResponse; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +final class PlatformTestHandler implements ModelClientInterface, ResponseConverterInterface +{ + public int $createCalls = 0; + + public function __construct( + private readonly ?ResponseInterface $create = null, + ) { + } + + public static function createPlatform(?ResponseInterface $create = null): Platform + { + $handler = new self($create); + + return new Platform([$handler], [$handler]); + } + + public function supports(Model $model): bool + { + return true; + } + + public function request(Model $model, array|string|object $payload, array $options = []): HttpResponse + { + ++$this->createCalls; + + return new MockResponse(); + } + + public function convert(HttpResponse $response, array $options = []): LlmResponse + { + return $this->create ?? new VectorResponse(new Vector([1, 2, 3])); + } +} diff --git a/src/store/tests/Double/TestStore.php b/src/store/tests/Double/TestStore.php new file mode 100644 index 000000000..2bfe409a3 --- /dev/null +++ b/src/store/tests/Double/TestStore.php @@ -0,0 +1,31 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests\Double; + +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\StoreInterface; + +final class TestStore implements StoreInterface +{ + /** + * @var VectorDocument[] + */ + public array $documents = []; + + public int $addCalls = 0; + + public function add(VectorDocument ...$documents): void + { + ++$this->addCalls; + $this->documents = array_merge($this->documents, $documents); + } +} diff --git a/src/store/tests/EmbedderTest.php b/src/store/tests/EmbedderTest.php new file mode 100644 index 000000000..d41463b77 --- /dev/null +++ b/src/store/tests/EmbedderTest.php @@ -0,0 +1,138 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Tests; + +use PhpLlm\LlmChain\Tests\Double\PlatformTestHandler; +use PhpLlm\LlmChain\Tests\Double\TestStore; +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Medium; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Psr\Log\LoggerInterface; +use Symfony\AI\Platform\Bridge\OpenAI\Embeddings; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Platform; +use Symfony\AI\Platform\Response\AsyncResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\VectorResponse; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\TextDocument; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Embedder; +use Symfony\Component\Clock\MockClock; +use Symfony\Component\Uid\Uuid; + +#[CoversClass(Embedder::class)] +#[Medium] +#[UsesClass(TextDocument::class)] +#[UsesClass(Vector::class)] +#[UsesClass(VectorDocument::class)] +#[UsesClass(ToolCallMessage::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(Embeddings::class)] +#[UsesClass(Platform::class)] +#[UsesClass(AsyncResponse::class)] +#[UsesClass(VectorResponse::class)] +final class EmbedderTest extends TestCase +{ + #[Test] + public function embedSingleDocument(): void + { + $document = new TextDocument($id = Uuid::v4(), 'Test content'); + $vector = new Vector([0.1, 0.2, 0.3]); + + $embedder = new Embedder( + PlatformTestHandler::createPlatform(new VectorResponse($vector)), + new Embeddings(), + $store = new TestStore(), + new MockClock(), + ); + + $embedder->embed($document); + + self::assertCount(1, $store->documents); + self::assertInstanceOf(VectorDocument::class, $store->documents[0]); + self::assertSame($id, $store->documents[0]->id); + self::assertSame($vector, $store->documents[0]->vector); + } + + #[Test] + public function embedEmptyDocumentList(): void + { + $logger = self::createMock(LoggerInterface::class); + $logger->expects(self::once())->method('debug')->with('No documents to embed'); + + $embedder = new Embedder( + PlatformTestHandler::createPlatform(), + new Embeddings(), + $store = new TestStore(), + new MockClock(), + $logger, + ); + + $embedder->embed([]); + + self::assertSame([], $store->documents); + } + + #[Test] + public function embedDocumentWithMetadata(): void + { + $metadata = new Metadata(['key' => 'value']); + $document = new TextDocument($id = Uuid::v4(), 'Test content', $metadata); + $vector = new Vector([0.1, 0.2, 0.3]); + + $embedder = new Embedder( + PlatformTestHandler::createPlatform(new VectorResponse($vector)), + new Embeddings(), + $store = new TestStore(), + new MockClock(), + ); + + $embedder->embed($document); + + self::assertSame(1, $store->addCalls); + self::assertCount(1, $store->documents); + self::assertInstanceOf(VectorDocument::class, $store->documents[0]); + self::assertSame($id, $store->documents[0]->id); + self::assertSame($vector, $store->documents[0]->vector); + self::assertSame(['key' => 'value'], $store->documents[0]->metadata->getArrayCopy()); + } + + #[Test] + public function embedWithSleep(): void + { + $vector1 = new Vector([0.1, 0.2, 0.3]); + $vector2 = new Vector([0.4, 0.5, 0.6]); + + $document1 = new TextDocument(Uuid::v4(), 'Test content 1'); + $document2 = new TextDocument(Uuid::v4(), 'Test content 2'); + + $embedder = new Embedder( + PlatformTestHandler::createPlatform(new VectorResponse($vector1, $vector2)), + new Embeddings(), + $store = new TestStore(), + $clock = new MockClock('2024-01-01 00:00:00'), + ); + + $embedder->embed( + documents: [$document1, $document2], + sleep: 3 + ); + + self::assertSame(1, $store->addCalls); + self::assertCount(2, $store->documents); + self::assertSame('2024-01-01 00:00:03', $clock->now()->format('Y-m-d H:i:s')); + } +}