diff --git a/src/Chain/InputProcessor/SystemPromptInputProcessor.php b/src/Chain/InputProcessor/SystemPromptInputProcessor.php index 59a00a08..d8c8f979 100644 --- a/src/Chain/InputProcessor/SystemPromptInputProcessor.php +++ b/src/Chain/InputProcessor/SystemPromptInputProcessor.php @@ -15,11 +15,11 @@ final readonly class SystemPromptInputProcessor implements InputProcessor { /** - * @param string $systemPrompt the system prompt to prepend to the input messages + * @param \Stringable|string $systemPrompt the system prompt to prepend to the input messages * @param ToolBoxInterface|null $toolBox the tool box to be used to append the tool definitions to the system prompt */ public function __construct( - private string $systemPrompt, + private \Stringable|string $systemPrompt, private ?ToolBoxInterface $toolBox = null, private LoggerInterface $logger = new NullLogger(), ) { @@ -35,7 +35,7 @@ public function processInput(Input $input): void return; } - $message = $this->systemPrompt; + $message = (string) $this->systemPrompt; if ($this->toolBox instanceof ToolBoxInterface && [] !== $this->toolBox->getMap() diff --git a/tests/Chain/InputProcessor/SystemPromptInputProcessorTest.php b/tests/Chain/InputProcessor/SystemPromptInputProcessorTest.php index 20004a45..b7bd3a15 100644 --- a/tests/Chain/InputProcessor/SystemPromptInputProcessorTest.php +++ b/tests/Chain/InputProcessor/SystemPromptInputProcessorTest.php @@ -147,4 +147,41 @@ public function execute(ToolCall $toolCall): mixed or not PROMPT, $messages[0]->content); } + + #[Test] + public function withStringableSystemPrompt(): void + { + $processor = new SystemPromptInputProcessor( + new SystemPromptService(), + new class implements ToolBoxInterface { + public function getMap(): array + { + return [ + new Metadata(ToolNoParams::class, 'tool_no_params', 'A tool without parameters', '__invoke', null), + ]; + } + + public function execute(ToolCall $toolCall): mixed + { + return null; + } + } + ); + + $input = new Input(new GPT(), new MessageBag(Message::ofUser('This is a user message')), []); + $processor->processInput($input); + + $messages = $input->messages->getMessages(); + self::assertCount(2, $messages); + self::assertInstanceOf(SystemMessage::class, $messages[0]); + self::assertInstanceOf(UserMessage::class, $messages[1]); + self::assertSame(<<content); + } } diff --git a/tests/Chain/InputProcessor/SystemPromptService.php b/tests/Chain/InputProcessor/SystemPromptService.php new file mode 100644 index 00000000..75449d70 --- /dev/null +++ b/tests/Chain/InputProcessor/SystemPromptService.php @@ -0,0 +1,13 @@ +