Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Adding multi-turn promptSendingOrchestrator #317

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions pyrit/common/net_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,6 @@ def get_httpx_client(use_async: bool = False, debug: bool = False):
PostType = Literal["json", "data"]


@retry(stop=stop_after_attempt(2), wait=wait_fixed(1))
def make_request_and_raise_if_error(
    endpoint_uri: str,
    method: str,
    request_body: dict[str, object] = None,
    headers: dict[str, str] = None,
    post_type: PostType = "json",
    debug: bool = False,
) -> httpx.Response:
    """
    Issue a synchronous HTTP request and raise on any error response.

    The request is retried once (two attempts total, one second apart).
    When a body is supplied it is encoded as JSON or as form data
    according to ``post_type``; an empty body sends no payload at all.
    """
    headers = headers or {}
    request_body = request_body or {}

    with get_httpx_client(debug=debug) as client:
        body_kwargs: dict[str, object] = {}
        if request_body:
            # Pick the httpx keyword matching the requested encoding.
            body_kwargs["json" if post_type == "json" else "data"] = request_body

        response = client.request(method=method, url=endpoint_uri, headers=headers, **body_kwargs)

        # Raises httpx.HTTPStatusError for 4xx/5xx, which triggers the retry decorator.
        response.raise_for_status()

        return response


@retry(stop=stop_after_attempt(2), wait=wait_fixed(1), reraise=True)
async def make_request_and_raise_if_error_async(
endpoint_uri: str,
Expand Down
58 changes: 52 additions & 6 deletions pyrit/orchestrator/prompt_sending_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,18 @@ async def send_prompts_async(
memory_labels: Optional[dict[str, str]] = None,
) -> list[PromptRequestResponse]:
"""
Sends the prompts to the prompt target, updating global memory labels with any new labels provided by the user.
Sends the prompts to the prompt target.

Args:
prompt_list (list[str]): The list of prompts to be sent.
prompt_type (PromptDataType): The type of prompt data. Defaults to "text".
memory_labels (dict[str, str], optional): A free-form dictionary of labels to apply to the prompts.
memory_labels (dict[str, str], optional): A free-form dictionary of additional labels to apply to the
prompts.
rlundeen2 marked this conversation as resolved.
Show resolved Hide resolved
These labels will be merged with the instance's global memory labels. Defaults to None.

Returns:
list[PromptRequestResponse]: The responses from sending the prompts.
"""
if memory_labels:
self._global_memory_labels.update(memory_labels)

requests: list[NormalizerRequest] = []
for prompt in prompt_list:
Expand All @@ -99,10 +98,50 @@ async def send_prompts_async(

return await self.send_normalizer_requests_async(
prompt_request_list=requests,
memory_labels=memory_labels,
)

async def send_prompt_async(
    self,
    *,
    prompt: str,
    prompt_type: PromptDataType = "text",
    memory_labels: Optional[dict[str, str]] = None,
    conversation_id: Optional[str] = None,
) -> PromptRequestResponse:
    """
    Sends a single prompt to the prompt target. Can be used for multi-turn using conversation_id.

    Args:
        prompt (str): The prompt to be sent.
        prompt_type (PromptDataType): The type of prompt data. Defaults to "text".
        memory_labels (dict[str, str], optional): A free-form dictionary of extra labels to apply to the prompt.
            These labels will be merged with the instance's global memory labels. Defaults to None.
        conversation_id (str, optional): The conversation ID to use for multi-turn conversation. Defaults to None.

    Returns:
        PromptRequestResponse: The response from sending the prompt.
    """

    # Wrap the raw prompt in a normalizer request so this orchestrator's
    # converters are applied before the prompt reaches the target.
    normalizer_request = self._create_normalizer_request(
        prompt_text=prompt,
        prompt_type=prompt_type,
        converters=self._prompt_converters,
    )

    # Per-call labels take precedence over the instance's global labels.
    return await self._prompt_normalizer.send_prompt_async(
        normalizer_request=normalizer_request,
        target=self._prompt_target,
        conversation_id=conversation_id,
        labels=self._combine_with_global_memory_labels(memory_labels),
        orchestrator_identifier=self.get_identifier(),
    )

async def send_normalizer_requests_async(
self, *, prompt_request_list: list[NormalizerRequest]
self,
*,
prompt_request_list: list[NormalizerRequest],
memory_labels: Optional[dict[str, str]] = None,
) -> list[PromptRequestResponse]:
"""
Sends the normalized prompts to the prompt target.
Expand All @@ -113,7 +152,7 @@ async def send_normalizer_requests_async(
responses: list[PromptRequestResponse] = await self._prompt_normalizer.send_prompt_batch_to_target_async(
requests=prompt_request_list,
target=self._prompt_target,
labels=self._global_memory_labels,
labels=self._combine_with_global_memory_labels(memory_labels),
orchestrator_identifier=self.get_identifier(),
batch_size=self._batch_size,
)
Expand Down Expand Up @@ -170,3 +209,10 @@ def print_conversations(self):
scores = self._memory.get_scores_by_prompt_ids(prompt_request_response_ids=[message.id])
for score in scores:
print(f"{Style.RESET_ALL}score: {score} : {score.score_rationale}")

def _combine_with_global_memory_labels(self, memory_labels: dict[str, str]) -> dict[str, str]:
"""
Combines the global memory labels with the provided memory labels.
The passed memory_leabels take prcedence with collisions.
"""
return {**(self._global_memory_labels or {}), **(memory_labels or {})}
14 changes: 14 additions & 0 deletions tests/orchestrator/test_prompt_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,20 @@ async def test_orchestrator_send_prompts_async_with_memory_labels(mock_target: M
assert entries[0].labels == expected_labels


@pytest.mark.asyncio
async def test_orchestrator_send_prompts_async_with_memory_labels_collision(mock_target: MockPromptTarget):
    # Instance-level label that the per-call label below collides with.
    labels = {"op_name": "op1"}
    orchestrator = PromptSendingOrchestrator(prompt_target=mock_target, memory_labels=labels)
    new_labels = {"op_name": "op2"}
    await orchestrator.send_prompts_async(prompt_list=["hello"], memory_labels=new_labels)
    assert mock_target.prompt_sent == ["hello"]

    # Per-call labels take precedence over the instance's global labels.
    expected_labels = {"op_name": "op2"}
    entries = orchestrator.get_memory()
    assert len(entries) == 2
    assert entries[0].labels == expected_labels


@pytest.mark.asyncio
async def test_orchestrator_get_score_memory(mock_target: MockPromptTarget):
scorer = AsyncMock()
Expand Down
33 changes: 17 additions & 16 deletions tests/test_common_net_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import respx

from unittest.mock import patch, MagicMock
from tenacity import RetryError
from pyrit.common.net_utility import get_httpx_client, make_request_and_raise_if_error
from pyrit.common.net_utility import get_httpx_client, make_request_and_raise_if_error_async


@pytest.mark.parametrize(
Expand All @@ -23,32 +22,33 @@ def test_get_httpx_client_type(use_async, expected_type):


@respx.mock
@pytest.mark.asyncio
async def test_make_request_and_raise_if_error_success():
    # A 200 response should be returned as-is, with the route actually hit.
    url = "http://testserver/api/test"
    method = "GET"
    mock_route = respx.get(url).respond(200, json={"status": "ok"})
    response = await make_request_and_raise_if_error_async(endpoint_uri=url, method=method)
    assert mock_route.called
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}


@respx.mock
@pytest.mark.asyncio
async def test_make_request_and_raise_if_error_failure():
    # A 500 response should surface as HTTPStatusError (retry uses reraise=True),
    # and the retry decorator should have attempted the request exactly twice.
    url = "http://testserver/api/fail"
    method = "GET"
    mock_route = respx.get(url).respond(500)

    with pytest.raises(httpx.HTTPStatusError):
        await make_request_and_raise_if_error_async(endpoint_uri=url, method=method)
    assert mock_route.called

    assert len(mock_route.calls) == 2


@respx.mock
def test_make_request_and_raise_if_error_retries():
@pytest.mark.asyncio
async def test_make_request_and_raise_if_error_retries():
url = "http://testserver/api/retry"
method = "GET"
call_count = 0
Expand All @@ -62,17 +62,18 @@ def response_callback(request):

mock_route = respx.route(method=method, url=url).mock(side_effect=response_callback)

with pytest.raises(RetryError):
make_request_and_raise_if_error(endpoint_uri=url, method=method)
with pytest.raises(httpx.HTTPStatusError):
await make_request_and_raise_if_error_async(endpoint_uri=url, method=method)
assert call_count == 2, "The request should have been retried exactly once."
assert mock_route.called


@pytest.mark.asyncio
async def test_debug_is_false_by_default():
    # The async helper should request an async client with debug disabled by default.
    with patch("pyrit.common.net_utility.get_httpx_client") as mock_get_httpx_client:
        mock_client_instance = MagicMock()
        mock_get_httpx_client.return_value = mock_client_instance

        await make_request_and_raise_if_error_async(endpoint_uri="http://example.com", method="GET")

        mock_get_httpx_client.assert_called_with(debug=False, use_async=True)
Loading