Commit 24acac1: Data cleaning scripts for dataset release (#2440)
1 parent: 9cf3c8b
Showing 21 changed files with 457 additions and 298 deletions.
@@ -6,7 +6,6 @@
import argparse
import json
from typing import Dict, Sequence, Optional


if __name__ == "__main__":
5 files renamed without changes.
fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py
119 additions, 0 deletions
"""
Compute summary statistics (users, models, languages, turns, token lengths)
for the dataset release.

From colab:
https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing
"""
import argparse
import datetime
import json
import os
import time

from pytz import timezone

# kaleido is plotly's static-image export engine; importing it up front
# makes a missing dependency fail fast instead of at write_image time.
import kaleido
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from tqdm import tqdm

pio.kaleido.scope.mathjax = None

parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str, required=True)
parser.add_argument("--scale", type=int, required=True)
args = parser.parse_args()

filename = args.in_file
scale = args.scale
convs = json.load(open(filename))
df = pd.DataFrame(convs)

print(f"#ips: {df['user_id'].nunique() * scale}")
print(f"#models: {df['model'].nunique()}")
print(f"#language: {df['language'].nunique()}")
print(f"#turns: {df['turn'].mean()}")

# Conversation counts per model
model_counts = df["model"].value_counts() * scale
fig = px.bar(x=model_counts.index, y=model_counts)
fig.update_layout(
    xaxis_title=None,
    yaxis_title="Count",
    height=200,
    width=950,
    margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image("model_count.pdf")

# Conversation counts for the top 25 languages
language_counts = df["language"].value_counts().head(25) * scale
fig = px.bar(x=language_counts.index, y=language_counts)
fig.update_layout(
    xaxis_title=None,
    yaxis_title="Count",
    height=200,
    width=950,
    margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image("language_count.pdf")

# Daily conversation counts, with a few trailing collection dates dropped
chat_dates = [
    datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d")
    for x in df["tstamp"]
]


def to_remove(x):
    for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]:
        if d in x:
            return True
    return False


chat_dates = [x for x in chat_dates if not to_remove(x)]

chat_dates_counts = pd.Series(chat_dates).value_counts() * scale
print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}")

fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts)
fig.update_layout(
    xaxis_title="Dates",
    yaxis_title="Count",
    height=200,
    width=950,
    margin=dict(l=0, r=0, t=0, b=0),
)
fig.show()
fig.write_image("daily_conversation_count.pdf")

# Token-length statistics with the Vicuna tokenizer
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "lmsys/vicuna-7b-v1.5", use_fast=False
)

prompts = []
responses = []
for conv in df["conversation"]:
    for row in conv:
        if row["role"] == "user":
            prompts.append(row["content"])
        else:
            responses.append(row["content"])

print(f"#prompts: {len(prompts)}")
print(f"#responses: {len(responses)}")

prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)]
print()
print(f"mean prompt len: {np.mean(prompt_lens):.2f}")

response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)]
print()
print(f"mean response len: {np.mean(response_lens):.2f}")
fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py
148 additions, 0 deletions
"""
Filter conversations for release.

Dependency:
pip install opencc-python-reimplemented

Usage:
python3 filter_bad_conv.py --in-file clean_battle_conv_20230630_tagged_v1_pii.json
"""
import argparse
from concurrent.futures import ProcessPoolExecutor
from collections import defaultdict
from enum import Enum, auto
import json
import os
import random

from tqdm import tqdm
import opencc

BLOCKED_WORDS_FILENAME = "blocked_words.json"
blocked_words = []
frequency = defaultdict(lambda: 0)

# Normalize Traditional Chinese to Simplified so a single blocked-word
# list matches both scripts.
cc_converter = opencc.OpenCC("t2s")


class TypeCode(Enum):
    """Reason a conversation is kept or dropped."""

    CORRECT = auto()
    ANONYMIZED = auto()
    REDACTED = auto()
    BAD_FORMAT = auto()
    BLOCKED_WORD = auto()
    BLOCKED_MODEL = auto()
    TOO_SHORT = auto()
    TOO_FREQUENT = auto()


def detect_type(conv):
    for key in ["conversation_a", "conversation_b", "conversation"]:
        if key not in conv:
            continue

        messages = [row["content"] for row in conv[key]]
        for msg in messages:
            if not isinstance(msg, str):
                return TypeCode.BAD_FORMAT

        if len(messages) == 0:
            return TypeCode.BAD_FORMAT

        user_prompts = [
            row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
        ]

        for msg in messages:
            msg = cc_converter.convert(msg.lower())
            if "<anonymized>" in msg:
                return TypeCode.ANONYMIZED
            if "<redacted>" in msg:
                return TypeCode.REDACTED

            for w in blocked_words:
                if w in msg:
                    return TypeCode.BLOCKED_WORD

    return TypeCode.CORRECT


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    parser.add_argument("--sample", type=int)
    args = parser.parse_args()

    # Read conversations
    convs = json.load(open(args.in_file))
    print(f"#conv: {len(convs)}")

    # Read blocked words. Note: worker processes see this list only on
    # fork-based platforms; with the spawn start method it stays empty.
    if os.path.exists(BLOCKED_WORDS_FILENAME):
        blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
        blocked_words = [cc_converter.convert(w) for w in blocked_words]

    # Start filter
    ct_bad_format = 0
    ct_anonymized = 0
    ct_redacted = 0
    ct_error = 0
    ct_lang_filter = 0
    ct_flagged = 0
    ct_blocked_word = 0
    ct_blocked_model = 0
    ct_too_short = 0
    ct_too_frequent = 0

    # Classify conversations in parallel
    type_codes = []
    with ProcessPoolExecutor() as executor:
        for result in tqdm(executor.map(detect_type, convs), total=len(convs)):
            type_codes.append(result)

    new_convs = []
    for conv, type_code in zip(convs, type_codes):
        if type_code == TypeCode.BAD_FORMAT:
            ct_bad_format += 1
            continue

        if type_code == TypeCode.ANONYMIZED:
            ct_anonymized += 1
            continue
        elif type_code == TypeCode.REDACTED:
            ct_redacted += 1
            continue
        elif type_code == TypeCode.BLOCKED_WORD:
            ct_blocked_word += 1
            continue
        elif type_code == TypeCode.BLOCKED_MODEL:
            ct_blocked_model += 1
            continue
        elif type_code == TypeCode.TOO_SHORT:
            ct_too_short += 1
            continue
        elif type_code == TypeCode.TOO_FREQUENT:
            ct_too_frequent += 1
            continue

        if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]:
            ct_flagged += 1
            continue

        if type_code in [TypeCode.CORRECT]:
            new_convs.append(conv)

    if args.sample:
        random.seed(42)
        random.shuffle(new_convs)
        new_convs = new_convs[: args.sample]

    print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
    print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
    print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
    print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}")
    print(f"new_conv: {len(new_convs)}")

    out_file = args.in_file.replace(".json", ".s1.json")
    print(f"Output to {out_file}")
    with open(out_file, "w") as fout:
        json.dump(new_convs, fout, indent=2, ensure_ascii=False)
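The blocked-word list itself is not part of the diff. Judging only from how the script loads it, blocked_words.json is a flat JSON array of strings, so a hypothetical placeholder for local testing could be written as:

# Hypothetical placeholder for blocked_words.json; the real release
# list is not shown in this commit, and its format (a flat JSON array
# of strings) is inferred from the json.load call above.
import json

with open("blocked_words.json", "w") as fout:
    json.dump(["example-blocked-term"], fout, ensure_ascii=False)
# python3 filter_bad_conv.py --in-file clean_battle_conv_20230630_tagged_v1_pii.json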
fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py
27 additions, 0 deletions
import argparse
import json

from tqdm import tqdm
import numpy as np


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    args = parser.parse_args()

    # Read conversations
    convs = json.load(open(args.in_file))
    print(f"#conv: {len(convs)}")

    # Delete some fields (drop timestamps and per-user identifiers before release)
    for c in convs:
        del c["tstamp"]
        del c["user_id"]

    # Write
    print(f"#out conv: {len(convs)}")
    out_file = args.in_file.replace(".json", ".s2.json")
    print(f"Output to {out_file}")
    with open(out_file, "w") as fout:
        json.dump(convs, fout, indent=2, ensure_ascii=False)
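Taken together, the three scripts form a short pipeline over the .s1/.s2 suffix convention. The order below is inferred from the fields each script touches rather than stated in the commit: compute_stats.py reads user_id and tstamp, so it must run before final_post_processing.py deletes them.

python3 filter_bad_conv.py --in-file clean_battle_conv_20230630_tagged_v1_pii.json
python3 compute_stats.py --in-file clean_battle_conv_20230630_tagged_v1_pii.s1.json --scale 1
python3 final_post_processing.py --in-file clean_battle_conv_20230630_tagged_v1_pii.s1.json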