Data cleaning scripts for dataset release (#2440)
merrymercy authored Sep 18, 2023
1 parent 9cf3c8b commit 24acac1
Showing 21 changed files with 457 additions and 298 deletions.
13 changes: 12 additions & 1 deletion docs/commands/leaderboard.md
@@ -11,5 +11,16 @@ python3 clean_battle_data.py
 
 ### Run Elo analysis
 ```
-python3 elo_analysis.py --clean-battle-file clean_battle_20230523.json
+python3 elo_analysis.py --clean-battle-file clean_battle_20230905.json
 ```
+
+### Copy files to HF space
+1. update plots
+```
+scp atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/elo_results_20230905.pkl .
+```
+
+2. update table
+```
+wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/raw/main/leaderboard_table_20230905.csv
+```
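
For reference, a minimal sketch for inspecting the downloaded pickle locally (the commit doesn't document the stored object's layout, so treat it as opaque):

```
import pickle

# elo_results_20230905.pkl is the file fetched in the scp step above.
with open("elo_results_20230905.pkl", "rb") as f:
    elo_results = pickle.load(f)

# Structure is undocumented here; inspect interactively.
print(type(elo_results))
```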
1 change: 0 additions & 1 deletion fastchat/data/merge.py
@@ -6,7 +6,6 @@
 
 import argparse
 import json
-from typing import Dict, Sequence, Optional
 
 
 if __name__ == "__main__":
19 changes: 11 additions & 8 deletions fastchat/serve/monitor/clean_battle_data.py
@@ -34,6 +34,7 @@
     "palm",
     "lamda",
     "google",
+    "llama",
     "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
 ]

@@ -43,11 +44,7 @@
 
 def get_log_files(max_num_files=None):
     dates = []
-    for month in [4, 5, 6, 7]:
-        for day in range(1, 32):
-            dates.append(f"2023-{month:02d}-{day:02d}")
-
-    for month in [8]:
+    for month in [4, 5, 6, 7, 8, 9]:
         for day in range(1, 32):
             dates.append(f"2023-{month:02d}-{day:02d}")

@@ -85,7 +82,7 @@ def replace_model_name(old_name):
 )
 
 
-def clean_battle_data(log_files):
+def clean_battle_data(log_files, exclude_model_names):
     data = []
     for filename in tqdm(log_files, desc="read files"):
         for retry in range(5):
@@ -173,6 +170,11 @@ def clean_battle_data(log_files):
         # Replace bard with palm
         models = [replace_model_name(m) for m in models]
 
+        # Exclude certain models
+        if any(x in exclude_model_names for x in models):
+            ct_invalid += 1
+            continue
+
         question_id = row["states"][0]["conv_id"]
         conversation_a = to_openai_format(
             row["states"][0]["messages"][row["states"][0]["offset"] :]
@@ -186,7 +188,7 @@
             all_ips[ip] = len(all_ips)
         user_id = all_ips[ip]
 
-        # Save the result
+        # Save the results
         battles.append(
             dict(
                 question_id=question_id,
@@ -228,10 +230,11 @@
     parser.add_argument(
         "--mode", type=str, choices=["simple", "conv_release"], default="simple"
     )
+    parser.add_argument("--exclude-model-names", type=str, nargs="+")
     args = parser.parse_args()
 
     log_files = get_log_files(args.max_num_files)
-    battles = clean_battle_data(log_files)
+    battles = clean_battle_data(log_files, args.exclude_model_names or [])
     last_updated_tstamp = battles[-1]["tstamp"]
     cutoff_date = datetime.datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
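
A toy illustration of the new exclusion path (model names here are placeholders, not values from the source):

```
# Mirrors the check added above: a battle is dropped when either side
# is an excluded model, and counted toward ct_invalid.
exclude_model_names = ["model-x"]
models = ["model-x", "model-y"]
if any(x in exclude_model_names for x in models):
    print("battle counted as invalid and skipped")
```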
6 changes: 1 addition & 5 deletions fastchat/serve/monitor/clean_chat_data.py
@@ -28,11 +28,7 @@
 
 def get_log_files(max_num_files=None):
     dates = []
-    for month in [4, 5, 6, 7]:
-        for day in range(1, 32):
-            dates.append(f"2023-{month:02d}-{day:02d}")
-
-    for month in [8]:
+    for month in [4, 5, 6, 7, 8, 9, 10]:
         for day in range(1, 32):
             dates.append(f"2023-{month:02d}-{day:02d}")

@@ -0,0 +1,119 @@
"""
From colab:
https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing
"""
import argparse
import datetime
import json
import os
from pytz import timezone
import time

import kaleido
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm

import plotly.io as pio

pio.kaleido.scope.mathjax = None

parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str, required=True)
parser.add_argument("--scale", type=int, required=True)
args = parser.parse_args()

filename = args.in_file
scale = args.scale
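# All counts below are multiplied by `scale`, so statistics computed on a
# subsample can be extrapolated to the full data (interpretation assumed;
# the flag is otherwise undocumented in this commit).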
convs = json.load(open(filename))
df = pd.DataFrame(convs)
df  # bare expression is a no-op outside a notebook; leftover from the colab original

print(f"#ips: {df['user_id'].nunique() * scale}")
print(f"#models: {df['model'].nunique()}")
print(f"#language: {df['language'].nunique()}")
print(f"#turns: {df['turn'].mean()}")

model_counts = df["model"].value_counts() * scale
# print("model counts", model_counts)
fig = px.bar(x=model_counts.index, y=model_counts)
fig.update_layout(
xaxis_title=None,
yaxis_title="Count",
height=200,
width=950,
margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image("model_count.pdf")


model_counts = df["language"].value_counts().head(25) * scale
fig = px.bar(x=model_counts.index, y=model_counts)
fig.update_layout(
xaxis_title=None,
yaxis_title="Count",
height=200,
width=950,
margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image("language_count.pdf")

chat_dates = [
datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d")
for x in df["tstamp"]
]


def to_remove(x):
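    # Hard-coded removal window (2023-08-04 through 2023-08-09); the commit
    # does not say why these dates are dropped.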
for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]:
if d in x:
return True
return False


chat_dates = [x for x in chat_dates if not to_remove(x)]

chat_dates_counts = pd.value_counts(chat_dates) * scale
print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}")

fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts)
fig.update_layout(
xaxis_title="Dates",
yaxis_title="Count",
height=200,
width=950,
margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image("daily_conversation_count.pdf")

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
"lmsys/vicuna-7b-v1.5", use_fast=False
)
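# use_fast=False selects the slow, SentencePiece-based tokenizer for this
# Llama-family model.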

prompts = []
responses = []
for conv in df["conversation"]:
for row in conv:
if row["role"] == "user":
prompts.append(row["content"])
else:
responses.append(row["content"])

print(f"#prompts: {len(prompts)}")
print(f"#responses: {len(responses)}")


prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)]
print()
print(f"mean prompt len: {np.mean(prompt_lens):.2f}")

# Count empty/None responses as length 0.
response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)]
print()
print(f"mean response len: {np.mean(response_lens):.2f}")
@@ -0,0 +1,148 @@
"""
Filter conversations for release.
Dependency:
pip install opencc-python-reimplemented
Usage:
python3 filter_bad_conv_lmsys_chat_1m.py --in-file clean_battle_conv_20230630_tagged_v1_pii.json
"""
import argparse
from concurrent.futures import ProcessPoolExecutor
from collections import defaultdict
from enum import Enum, auto
import json
import os
import random

from tqdm import tqdm
import opencc

BLOCKED_WORDS_FILENAME = "blocked_words.json"
blocked_words = []
frequency = defaultdict(lambda: 0)

cc_converter = opencc.OpenCC("t2s")


class TypeCode(Enum):
CORRECT = auto()
ANONYMIZED = auto()
REDACTED = auto()
BAD_FORMAT = auto()
BLOCKED_WORD = auto()
BLOCKED_MODEL = auto()
TOO_SHORT = auto()
TOO_FREQUENT = auto()
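    # As committed here, detect_type only ever returns CORRECT, ANONYMIZED,
    # REDACTED, BAD_FORMAT, or BLOCKED_WORD; the remaining codes are tallied
    # below but never produced.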


def detect_type(conv):
for key in ["conversation_a", "conversation_b", "conversation"]:
if key not in conv:
continue

messages = [row["content"] for row in conv[key]]
for msg in messages:
if not isinstance(msg, str):
return TypeCode.BAD_FORMAT

if len(messages) == 0:
return TypeCode.BAD_FORMAT

user_prompts = [
row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
]

for msg in messages:
msg = cc_converter.convert(msg.lower())
if "<anonymized>" in msg:
return TypeCode.ANONYMIZED
if "<redacted>" in msg:
return TypeCode.REDACTED

for w in blocked_words:
if w in msg:
return TypeCode.BLOCKED_WORD

return TypeCode.CORRECT
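
# Example on hypothetical data: a message containing "<anonymized>" is flagged
# for removal, e.g.
#   detect_type({"conversation": [{"role": "user", "content": "hi <anonymized>"}]})
# returns TypeCode.ANONYMIZED.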


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str, required=True)
parser.add_argument("--sample", type=int)
args = parser.parse_args()

# Read conversations
convs = json.load(open(args.in_file))
print(f"#conv: {len(convs)}")

# Read blocked words
if os.path.exists(BLOCKED_WORDS_FILENAME):
blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
blocked_words = [cc_converter.convert(w) for w in blocked_words]

# Start filter
ct_bad_format = 0
ct_anonymized = 0
ct_redacted = 0
ct_error = 0
ct_lang_filter = 0
ct_flagged = 0
ct_blocked_word = 0
ct_blocked_model = 0
ct_too_short = 0
ct_too_frequent = 0

type_codes = []
with ProcessPoolExecutor() as executor:
for result in tqdm(executor.map(detect_type, convs), total=len(convs)):
type_codes.append(result)

new_convs = []
for conv, type_code in zip(convs, type_codes):
if type_code == TypeCode.BAD_FORMAT:
ct_bad_format += 1
continue

if type_code == TypeCode.ANONYMIZED:
ct_anonymized += 1
continue
elif type_code == TypeCode.REDACTED:
ct_redacted += 1
continue
elif type_code == TypeCode.BLOCKED_WORD:
ct_blocked_word += 1
continue
elif type_code == TypeCode.BLOCKED_MODEL:
ct_blocked_model += 1
continue
elif type_code == TypeCode.TOO_SHORT:
ct_too_short += 1
continue
elif type_code == TypeCode.TOO_FREQUENT:
ct_too_frequent += 1
continue

if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]:
ct_flagged += 1
continue

if type_code in [TypeCode.CORRECT]:
new_convs.append(conv)

if args.sample:
random.seed(42)
random.shuffle(new_convs)
new_convs = new_convs[: args.sample]

print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}")
print(f"new_conv: {len(new_convs)}")

out_file = args.in_file.replace(".json", ".s1.json")
print(f"Output to {out_file}")
with open(out_file, "w") as fout:
json.dump(new_convs, fout, indent=2, ensure_ascii=False)
@@ -0,0 +1,27 @@
import argparse
import json

from tqdm import tqdm
import numpy as np


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str, required=True)
args = parser.parse_args()

# Read conversations
convs = json.load(open(args.in_file))
print(f"#conv: {len(convs)}")

# Delete some fields
for c in convs:
del c["tstamp"]
del c["user_id"]

# Write
print(f"#out conv: {len(convs)}")
out_file = args.in_file.replace(".json", ".s2.json")
print(f"Output to {out_file}")
with open(out_file, "w") as fout:
json.dump(convs, fout, indent=2, ensure_ascii=False)