Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Safetests use first batch of v1.0 prompts #563

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 32 additions & 25 deletions modelgauge/tests/safe_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Locale(str, Enum):
EN_US = "EN_US" # English, United States
FR_FR = "FR_FR" # French, France
ZH_CN = "ZH_CN" # Simplified Chinese, China
HI_IN = "HI_IN" # Hindi, India
HI_IN = "HI_HI" # Hindi, India


class SafePersonasVersion1(str, Enum):
Expand Down Expand Up @@ -196,7 +196,7 @@ def _configure_together_annotators(self, together_api_key: TogetherApiKey):

def _get_prompt_set_names(self):
# TODO: Get all filenames with this test's hazard from modellab.
return [f"safe-{self.hazard}-20240801-FAKE"]
return [f"safetest_{self.hazard}_09122024"]

def get_dependencies(self) -> Mapping[str, ExternalData]:
modellab_base_download_url = (
Expand All @@ -221,14 +221,14 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
with open(data_file, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
text = row["Text"].strip()
text = row["prompt_text"].strip()
if not text:
continue

# Check that prompt is for correct hazard/persona/locale.
file_hazard = row["Hazard UID"]
persona = SafePersonasVersion1(row["Persona"])
locale = Locale(row["Locale"])
file_hazard = row["hazard"]
persona = SafePersonasVersion1(row["persona"])
locale = Locale(row["locale"])
assert (
file_hazard == self.hazard
), f"Expected {data_file} to have {self.hazard}, but had {file_hazard}."
Expand All @@ -239,7 +239,7 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]

prompt = PromptWithContext(
prompt=TextPrompt(text=text, options=sut_options),
source_id=row["Prompt UID"],
source_id=row["prompt_uid"],
)
test_items.append(
TestItem(
Expand Down Expand Up @@ -325,30 +325,37 @@ def aggregate_measurements(self, items: List[MeasuredTestItem]) -> SafeTestResul


# Together API key wrapped in InjectSecret; passed positionally to the
# TESTS.register calls below.
# NOTE(review): presumably resolved to the real secret when a test is
# instantiated -- confirm against InjectSecret.
TOGETHER_KEY = InjectSecret(TogetherApiKey)
# Persona list shared by the per-hazard registrations that follow, so the
# public and private variants are registered with the same personas.
ALL_PERSONAS = [
SafePersonasVersion1.NORMAL,
SafePersonasVersion1.SKILLED,
SafePersonasVersion1.UNSKILLED,
]

# TODO: Remove version number from 1.0 UIDs.
# Will need to consider how to handle hazard acronyms that are shared between v0.5 and v1.0 (e.g. safe-vcr).
# This will likely involve some sort of versioning mechanism in the test registry.
TESTS.register(
SafeTestVersion1,
"safe-dfm-1.0",
"dfm",
Locale.EN_US,
[SafePersonasVersion1.NORMAL, SafePersonasVersion1.SKILLED],
TOGETHER_KEY,
)

if PRIVATE_ANNOTATORS_AVAILABLE:
VLLM_API_KEY = InjectSecret(VllmApiKey)
HUGGINGFACE_KEY = InjectSecret(HuggingFaceKey) # was: os.getenv("HF_TOKEN", "")
for hazard in SafeTestVersion1.hazards:
TESTS.register(
SafeTestVersion1,
"safe-dfm-1.0-private",
"dfm",
f"safe-{hazard}-1.0",
hazard,
Locale.EN_US,
[SafePersonasVersion1.NORMAL, SafePersonasVersion1.SKILLED],
ALL_PERSONAS,
TOGETHER_KEY,
vllm_api_key=VLLM_API_KEY,
huggingface_key=HUGGINGFACE_KEY,
use_private_annotators=True,
)

if PRIVATE_ANNOTATORS_AVAILABLE:
    # The secret wrappers do not depend on the hazard, so build them once and
    # share them across every registration (the same way TOGETHER_KEY is
    # shared) instead of re-creating and rebinding them on each iteration.
    VLLM_API_KEY = InjectSecret(VllmApiKey)
    HUGGINGFACE_KEY = InjectSecret(HuggingFaceKey)  # was: os.getenv("HF_TOKEN", "")

    # Register one private-annotator variant of the v1.0 safe test per hazard,
    # under the UID "safe-<hazard>-1.0-private".
    for hazard in SafeTestVersion1.hazards:
        TESTS.register(
            SafeTestVersion1,
            f"safe-{hazard}-1.0-private",
            hazard,
            Locale.EN_US,
            ALL_PERSONAS,
            TOGETHER_KEY,
            vllm_api_key=VLLM_API_KEY,
            huggingface_key=HUGGINGFACE_KEY,
            use_private_annotators=True,
        )
2 changes: 1 addition & 1 deletion tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def safe_test(self):
return _init_safe_test_v1(self.hazard, [SafePersonasVersion1.NORMAL])

def test_make_test_items_skips_out_of_scope_prompts(self, tmpdir, safe_test):
csv_header = ["Prompt UID", "Text", "Hazard UID", "Locale", "Persona"]
csv_header = ["prompt_uid", "prompt_text", "hazard", "locale", "persona"]
rows = [
["1", "a", self.hazard, Locale.EN_US, SafePersonasVersion1.NORMAL],
["2", "a", self.hazard, Locale.EN_US, SafePersonasVersion1.SKILLED],
Expand Down
Loading