Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build: update versions of pre-commit hooks #47

Merged
merged 5 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,28 @@ repos:
- id: check-added-large-files
- id: detect-private-key

- repo: https://github.com/psf/black
rev: 23.7.0
# https://black.readthedocs.io/en/stable/integrations/source_version_control.html
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.2.0
hooks:
- id: black
language_version: python3.10

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8

- repo: https://github.com/pycqa/pylint
rev: v2.16.2
rev: v3.0.3
hooks:
- id: pylint
args: [ --disable=all, --enable=unused-import ]
args: [--disable=all, --enable=unused-import]
exclude: NOTICE.txt

# Calling local version of mypy rather the one from GitHub https://github.com/pre-commit/mirrors-mypy
# because there is no way to sync the configs in pyproject.toml and the .pre-commit-config.yaml
# See https://stackoverflow.com/a/75003826/4956355
- repo: local
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
name: mypy
entry: mypy
language: system
types: [ python ]
args: [--install-types, --non-interactive, --ignore-missing-imports]
26 changes: 13 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ name = "pyrit"
version = "0.1.0"
description = "The Python Risk Identification Tool for LLMs (PyRIT) is a library used to assess the robustness of LLMs"
authors = [
{name = "Microsoft AI Red Team", email = "airedteam@microsoft.com"},
{name = "dlmgary"},
{name = "amandajean119"},
{name = "microsiska"},
{name = "rdheekonda"},
{name = "rlundeen2"},
{name = "romanlutz"},
{name = "jbolor21"},
{name = "nina-msft"},
{ name = "Microsoft AI Red Team", email = "airedteam@microsoft.com" },
{ name = "dlmgary" },
{ name = "amandajean119" },
{ name = "microsiska" },
{ name = "rdheekonda" },
{ name = "rlundeen2" },
{ name = "romanlutz" },
{ name = "jbolor21" },
{ name = "nina-msft" },
]
readme = "README.md"
license = {text = "MIT"}
license = { text = "MIT" }
keywords = [
"llm",
"ai-security",
"ai-red-team",
"ai-robustness",
"ai-robustness-testing",
"ai-risk-assessment"
"ai-risk-assessment",
]
classifiers = [
"Development Status :: 3 - Alpha",
Expand All @@ -48,15 +48,15 @@ dependencies = [
"tokenizers>=0.15.0",
"torch==2.1.2",
"transformers>=4.36.0",
"types-requests>=2.31.0.2"
"types-requests>=2.31.0.2",
]

[project.optional-dependencies]
dev = [
"black>=23.3.0",
"flake8>=6.0.0",
"flake8-copyright>=0.2.0",
"mypy>=1.2.0",
"mypy>=1.8.0",
"pre-commit>=3.3.3",
"pytest>=7.3.1",
"pytest-cov>=4.0.0",
Expand Down
8 changes: 3 additions & 5 deletions pyrit/common/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
import textwrap

import termcolor

from pyrit.memory.memory_models import ConversationMemoryEntry
from pyrit.models import ChatMessage

_COLOR_TYPE = termcolor._types.Color
from termcolor._types import Color


def print_chat_messages_with_color(
messages: list[ChatMessage | ConversationMemoryEntry],
max_content_character_width: int = 80,
left_padding_width: int = 20,
custom_colors: dict[str, _COLOR_TYPE] = None,
custom_colors: dict[str, Color] = None,
) -> None:
"""Print chat messages with color to console.

Expand All @@ -29,7 +27,7 @@ def print_chat_messages_with_color(
Returns:
None
"""
role_to_color: dict[str, _COLOR_TYPE] = {
role_to_color: dict[str, Color] = {
"system": "red",
"user": "green",
"assistant": "blue",
Expand Down
14 changes: 9 additions & 5 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from typing import Optional
from uuid import UUID, uuid4

from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, ConfigDict, Field


class EmbeddingMemoryData(BaseModel, extra=Extra.forbid):
class EmbeddingMemoryData(BaseModel):
model_config = ConfigDict(extra="forbid")
uuid: UUID = Field(default_factory=uuid4)
embedding: list[float]
embedding_type_name: str


class ConversationMemoryEntry(BaseModel, extra=Extra.forbid):
class ConversationMemoryEntry(BaseModel):
"""
Represents a single memory entry.

Expand All @@ -29,6 +30,7 @@ class ConversationMemoryEntry(BaseModel, extra=Extra.forbid):
future references like scoring information.
"""

model_config = ConfigDict(extra="forbid")
role: str
content: str
conversation_id: str
Expand All @@ -44,11 +46,13 @@ def __str__(self):


# This class is convenient for serialization
class ConversationMemoryEntryList(BaseModel, extra=Extra.forbid):
class ConversationMemoryEntryList(BaseModel):
model_config = ConfigDict(extra="forbid")
conversations: list[ConversationMemoryEntry]


class ConversationMessageWithSimilarity(BaseModel, extra=Extra.forbid):
class ConversationMessageWithSimilarity(BaseModel):
model_config = ConfigDict(extra="forbid")
role: str
content: str
metric: str
Expand Down
17 changes: 11 additions & 6 deletions pyrit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Literal, Optional, Type, TypeVar

import yaml
from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict


@dataclass
Expand All @@ -23,7 +23,8 @@ class Score:
score_explanation: str = ""


class PromptResponse(BaseModel, extra=Extra.forbid):
class PromptResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
# The text response for the prompt
completion: str
# The original prompt
Expand Down Expand Up @@ -217,23 +218,27 @@ def apply_custom_metaprompt_parameters(self, **kwargs) -> str:
return final_prompt


class ChatMessage(BaseModel, extra=Extra.forbid):
class ChatMessage(BaseModel):
model_config = ConfigDict(extra="forbid")
role: str
content: str


class EmbeddingUsageInformation(BaseModel, extra=Extra.forbid):
class EmbeddingUsageInformation(BaseModel):
model_config = ConfigDict(extra="forbid")
prompt_tokens: int
total_tokens: int


class EmbeddingData(BaseModel, extra=Extra.forbid):
class EmbeddingData(BaseModel):
model_config = ConfigDict(extra="forbid")
embedding: list[float]
index: int
object: str


class EmbeddingResponse(BaseModel, extra=Extra.forbid):
class EmbeddingResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
model: str
object: str
usage: EmbeddingUsageInformation
Expand Down
14 changes: 8 additions & 6 deletions tests/memory/test_file_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,10 @@ def test_explicit_embedding_model_set():


def test_default_embedding_model_set_none():
with NamedTemporaryFile(suffix=".json.memory") as tmp, patch(
"pyrit.memory.file_memory.default_memory_embedding_factory"
) as mock:
with (
NamedTemporaryFile(suffix=".json.memory") as tmp,
patch("pyrit.memory.file_memory.default_memory_embedding_factory") as mock,
):
mock.return_value = None
memory = FileMemory(filepath=tmp.name)
assert memory.memory_embedding is None
Expand All @@ -228,9 +229,10 @@ def test_default_embedding_model_set_none():
def test_default_embedding_model_set_correctly():
embedding = AzureTextEmbedding(api_key="testkey", api_base="testbase", model="deployment")

with NamedTemporaryFile(suffix=".json.memory") as tmp, patch(
"pyrit.memory.file_memory.default_memory_embedding_factory"
) as mock:
with (
NamedTemporaryFile(suffix=".json.memory") as tmp,
patch("pyrit.memory.file_memory.default_memory_embedding_factory") as mock,
):
mock.return_value = embedding
memory = FileMemory(filepath=tmp.name)
assert memory.memory_embedding is embedding
Loading