forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreparing_sys_compliance_data.py
32 lines (24 loc) · 1.12 KB
/
preparing_sys_compliance_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import re
import warnings

from datasets import load_dataset
# Load the raw system-prompt-compliance outputs (JSON Lines, judging by the
# .jsonl extension) as a single "train" split. NOTE(review): the data file is a
# relative path, presumably resolved against the current working directory
# rather than this script's location — confirm how the script is launched.
dataset = load_dataset(
    'json',
    data_files='../system-prompt-compliance/output/full_outputs.jsonl',
    split='train',
)
# Patterns for the two policy-list layouts recognised in system prompts. Each
# captures the text that follows a bold "**Title:**" heading up to end of line.
# Compiled once at import time instead of being re-parsed for every row.
# Numbered layout, e.g. "1. **Title:** policy text"
_NUMBERED_POLICY_RE = re.compile(r'\d+\.\s\*\*[^:]+:\*\*\s+([^\n]+)')
# Bullet layout, e.g. "* **Title:** policy text"
_BULLET_POLICY_RE = re.compile(r'\*\s\*\*[^:]+:\*\*\s+([^\n]+)')


def _extract_policies(system_prompt):
    """Return the policy texts listed in ``system_prompt``, in order.

    The numbered layout is tried first; the ``*`` bullet layout is used only
    when no numbered policy is found. Returns an empty list if neither matches.
    """
    policies = _NUMBERED_POLICY_RE.findall(system_prompt)
    if not policies:
        policies = _BULLET_POLICY_RE.findall(system_prompt)
    return policies


def _normalize_policy(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    return " ".join(text.split())


def create_new_label(row):
    """Build per-policy pass/fail training labels for one row (for ``Dataset.map``).

    Reads ``row['system_prompt']``, ``row['label']``, ``row['target_policy']``
    and ``row['explanation']``. The policies listed in the system prompt are
    extracted. If the row's label is not ``"compliant"``, the first policy whose
    text equals ``target_policy`` (ignoring whitespace differences) is marked
    ``"fail"`` and receives the row's explanation. Every other policy is
    ``"pass"`` with an empty explanation.

    Returns:
        dict with ``"training_label"`` (list of ``"pass"``/``"fail"``) and
        ``"training_explanation"`` (list of str), each with one entry per
        extracted policy (both empty if no policies are found).

    Emits a ``UserWarning`` when a non-compliant row's target policy cannot be
    located, since that row then ends up labelled all-"pass".
    """
    policies = _extract_policies(row['system_prompt'])
    result = ["pass"] * len(policies)
    explanation = [""] * len(policies)

    # A compliant row is all-"pass" whatever target_policy says, so return
    # before reading it; a None/absent target_policy on such rows is harmless.
    if row["label"] == "compliant":
        return {"training_label": result, "training_explanation": explanation}

    target = _normalize_policy(row['target_policy'] or "")
    # Compare whitespace-normalised text so spacing or line-wrapping differences
    # between the prompt and target_policy don't hide the match. An empty
    # target never matches anything.
    for i, policy in enumerate(policies):
        if target and _normalize_policy(policy) == target:
            result[i] = "fail"
            explanation[i] = row["explanation"]
            break
    else:
        # Otherwise the violation would silently vanish from the training data.
        warnings.warn(
            f"create_new_label: target policy not found among "
            f"{len(policies)} parsed policies; row labelled all-pass: "
            f"{target[:80]!r}"
        )
    return {"training_label": result, "training_explanation": explanation}
# Run create_new_label on every row; the "training_label" and
# "training_explanation" keys it returns are added to the dataset as columns.
dataset = dataset.map(function=create_new_label)