Skip to content

Commit

Permalink
build: update versions of pre-commit hooks (#47)
Browse files Browse the repository at this point in the history
* build: update versions of pre-commit hooks

* build: update mypy pre-commit hook to install type hints

* refactor: migrate `pydantic` objects

The `extra` field was deprecated, see https://docs.pydantic.dev/latest/migration/#changes-to-dataclasses

* refactor: fix type hint issue with termcolor `Color` type

* style: fix spacing and formatting
  • Loading branch information
dlmgary authored Feb 16, 2024
1 parent 9de40a8 commit afc45fc
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 48 deletions.
23 changes: 10 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,28 @@ repos:
- id: check-added-large-files
- id: detect-private-key

- repo: https://github.com/psf/black
rev: 23.7.0
# https://black.readthedocs.io/en/stable/integrations/source_version_control.html
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.2.0
hooks:
- id: black
language_version: python3.10

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8

- repo: https://github.com/pycqa/pylint
rev: v2.16.2
rev: v3.0.3
hooks:
- id: pylint
args: [ --disable=all, --enable=unused-import ]
args: [--disable=all, --enable=unused-import]
exclude: NOTICE.txt

# Calling local version of mypy rather the one from GitHub https://github.com/pre-commit/mirrors-mypy
# because there is no way to sync the configs in pyproject.toml and the .pre-commit-config.yaml
# See https://stackoverflow.com/a/75003826/4956355
- repo: local
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
name: mypy
entry: mypy
language: system
types: [ python ]
args: [--install-types, --non-interactive, --ignore-missing-imports]
26 changes: 13 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ name = "pyrit"
version = "0.1.0"
description = "The Python Risk Identification Tool for LLMs (PyRIT) is a library used to assess the robustness of LLMs"
authors = [
{name = "Microsoft AI Red Team", email = "airedteam@microsoft.com"},
{name = "dlmgary"},
{name = "amandajean119"},
{name = "microsiska"},
{name = "rdheekonda"},
{name = "rlundeen2"},
{name = "romanlutz"},
{name = "jbolor21"},
{name = "nina-msft"},
{ name = "Microsoft AI Red Team", email = "airedteam@microsoft.com" },
{ name = "dlmgary" },
{ name = "amandajean119" },
{ name = "microsiska" },
{ name = "rdheekonda" },
{ name = "rlundeen2" },
{ name = "romanlutz" },
{ name = "jbolor21" },
{ name = "nina-msft" },
]
readme = "README.md"
license = {text = "MIT"}
license = { text = "MIT" }
keywords = [
"llm",
"ai-security",
"ai-red-team",
"ai-robustness",
"ai-robustness-testing",
"ai-risk-assessment"
"ai-risk-assessment",
]
classifiers = [
"Development Status :: 3 - Alpha",
Expand All @@ -48,15 +48,15 @@ dependencies = [
"tokenizers>=0.15.0",
"torch==2.1.2",
"transformers>=4.36.0",
"types-requests>=2.31.0.2"
"types-requests>=2.31.0.2",
]

[project.optional-dependencies]
dev = [
"black>=23.3.0",
"flake8>=6.0.0",
"flake8-copyright>=0.2.0",
"mypy>=1.2.0",
"mypy>=1.8.0",
"pre-commit>=3.3.3",
"pytest>=7.3.1",
"pytest-cov>=4.0.0",
Expand Down
8 changes: 3 additions & 5 deletions pyrit/common/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
import textwrap

import termcolor

from pyrit.memory.memory_models import ConversationMemoryEntry
from pyrit.models import ChatMessage

_COLOR_TYPE = termcolor._types.Color
from termcolor._types import Color


def print_chat_messages_with_color(
messages: list[ChatMessage | ConversationMemoryEntry],
max_content_character_width: int = 80,
left_padding_width: int = 20,
custom_colors: dict[str, _COLOR_TYPE] = None,
custom_colors: dict[str, Color] = None,
) -> None:
"""Print chat messages with color to console.
Expand All @@ -29,7 +27,7 @@ def print_chat_messages_with_color(
Returns:
None
"""
role_to_color: dict[str, _COLOR_TYPE] = {
role_to_color: dict[str, Color] = {
"system": "red",
"user": "green",
"assistant": "blue",
Expand Down
14 changes: 9 additions & 5 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from typing import Optional
from uuid import UUID, uuid4

from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, ConfigDict, Field


class EmbeddingMemoryData(BaseModel, extra=Extra.forbid):
class EmbeddingMemoryData(BaseModel):
model_config = ConfigDict(extra="forbid")
uuid: UUID = Field(default_factory=uuid4)
embedding: list[float]
embedding_type_name: str


class ConversationMemoryEntry(BaseModel, extra=Extra.forbid):
class ConversationMemoryEntry(BaseModel):
"""
Represents a single memory entry.
Expand All @@ -29,6 +30,7 @@ class ConversationMemoryEntry(BaseModel, extra=Extra.forbid):
future references like scoring information.
"""

model_config = ConfigDict(extra="forbid")
role: str
content: str
conversation_id: str
Expand All @@ -44,11 +46,13 @@ def __str__(self):


# This class is convenient for serialization
class ConversationMemoryEntryList(BaseModel, extra=Extra.forbid):
class ConversationMemoryEntryList(BaseModel):
model_config = ConfigDict(extra="forbid")
conversations: list[ConversationMemoryEntry]


class ConversationMessageWithSimilarity(BaseModel, extra=Extra.forbid):
class ConversationMessageWithSimilarity(BaseModel):
model_config = ConfigDict(extra="forbid")
role: str
content: str
metric: str
Expand Down
17 changes: 11 additions & 6 deletions pyrit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Literal, Optional, Type, TypeVar

import yaml
from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict


@dataclass
Expand All @@ -23,7 +23,8 @@ class Score:
score_explanation: str = ""


class PromptResponse(BaseModel, extra=Extra.forbid):
class PromptResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
# The text response for the prompt
completion: str
# The original prompt
Expand Down Expand Up @@ -217,23 +218,27 @@ def apply_custom_metaprompt_parameters(self, **kwargs) -> str:
return final_prompt


class ChatMessage(BaseModel, extra=Extra.forbid):
class ChatMessage(BaseModel):
model_config = ConfigDict(extra="forbid")
role: str
content: str


class EmbeddingUsageInformation(BaseModel, extra=Extra.forbid):
class EmbeddingUsageInformation(BaseModel):
model_config = ConfigDict(extra="forbid")
prompt_tokens: int
total_tokens: int


class EmbeddingData(BaseModel, extra=Extra.forbid):
class EmbeddingData(BaseModel):
model_config = ConfigDict(extra="forbid")
embedding: list[float]
index: int
object: str


class EmbeddingResponse(BaseModel, extra=Extra.forbid):
class EmbeddingResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
model: str
object: str
usage: EmbeddingUsageInformation
Expand Down
14 changes: 8 additions & 6 deletions tests/memory/test_file_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,10 @@ def test_explicit_embedding_model_set():


def test_default_embedding_model_set_none():
with NamedTemporaryFile(suffix=".json.memory") as tmp, patch(
"pyrit.memory.file_memory.default_memory_embedding_factory"
) as mock:
with (
NamedTemporaryFile(suffix=".json.memory") as tmp,
patch("pyrit.memory.file_memory.default_memory_embedding_factory") as mock,
):
mock.return_value = None
memory = FileMemory(filepath=tmp.name)
assert memory.memory_embedding is None
Expand All @@ -228,9 +229,10 @@ def test_default_embedding_model_set_none():
def test_default_embedding_model_set_correctly():
embedding = AzureTextEmbedding(api_key="testkey", api_base="testbase", model="deployment")

with NamedTemporaryFile(suffix=".json.memory") as tmp, patch(
"pyrit.memory.file_memory.default_memory_embedding_factory"
) as mock:
with (
NamedTemporaryFile(suffix=".json.memory") as tmp,
patch("pyrit.memory.file_memory.default_memory_embedding_factory") as mock,
):
mock.return_value = embedding
memory = FileMemory(filepath=tmp.name)
assert memory.memory_embedding is embedding

0 comments on commit afc45fc

Please sign in to comment.