add dtype and seed (#2430)
Ying1123 authored Sep 18, 2023
1 parent 318d070 commit 9cf3c8b
Showing 7 changed files with 84 additions and 17 deletions.
4 changes: 2 additions & 2 deletions fastchat/llm_judge/README.md
@@ -10,7 +10,7 @@ To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as j
- [Review Pre-Generated Model Answers and Judgments](#review-pre-generated-model-answers-and-judgments)
- [MT-Bench](#mt-bench)
- [Agreement Computation](#agreement-computation)
- [Dataset](#dataset)
- [Datasets](#datasets)
- [Citation](#citation)

## Install
@@ -133,7 +133,7 @@ We released 3.3K human annotations for model responses generated by 6 models in

This Colab [notebook](https://colab.research.google.com/drive/1ctgygDRJhVGUJTQy8-bRZCl1WNcT8De6?usp=sharing) shows how to compute the agreement between humans and the GPT-4 judge with the dataset. Our results show that humans and the GPT-4 judge achieve over 80\% agreement, the same level of agreement as between humans.

## Dataset
## Datasets
- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)
- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)

42 changes: 29 additions & 13 deletions fastchat/llm_judge/gen_model_answer.py
@@ -15,6 +15,7 @@

from fastchat.llm_judge.common import load_questions, temperature_config
from fastchat.model import load_model, get_conversation_template
from fastchat.utils import str_to_torch_dtype


def run_eval(
@@ -29,6 +30,7 @@
num_gpus_per_model,
num_gpus_total,
max_gpu_memory,
dtype,
):
questions = load_questions(question_file, question_begin, question_end)
# Randomly shuffle the questions to balance the load
@@ -45,7 +47,7 @@
else:
get_answers_func = get_model_answers

chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) // 2
chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
ans_handles = []
for i in range(0, len(questions), chunk_size):
ans_handles.append(
@@ -58,6 +60,7 @@
num_choices,
num_gpus_per_model,
max_gpu_memory,
dtype=dtype,
)
)

@@ -75,12 +78,14 @@ def get_model_answers(
num_choices,
num_gpus_per_model,
max_gpu_memory,
dtype,
):
model, tokenizer = load_model(
model_path,
device="cuda",
num_gpus=num_gpus_per_model,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=False,
cpu_offloading=False,
debug=False,
@@ -192,7 +197,9 @@ def reorg_answer_file(answer_file):
required=True,
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument("--model-id", type=str, required=True)
parser.add_argument(
"--model-id", type=str, required=True, help="A custom name for the model."
)
parser.add_argument(
"--bench-name",
type=str,
@@ -234,6 +241,14 @@ def reorg_answer_file(answer_file):
type=str,
help="Maxmum GPU memory used for model weights per GPU.",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "float16", "bfloat16"],
help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
default=None,
)

args = parser.parse_args()

if args.num_gpus_total // args.num_gpus_per_model > 1:
@@ -250,17 +265,18 @@ def reorg_answer_file(answer_file):
print(f"Output to {answer_file}")

run_eval(
args.model_path,
args.model_id,
question_file,
args.question_begin,
args.question_end,
answer_file,
args.max_new_token,
args.num_choices,
args.num_gpus_per_model,
args.num_gpus_total,
args.max_gpu_memory,
model_path=args.model_path,
model_id=args.model_id,
question_file=question_file,
question_begin=args.question_begin,
question_end=args.question_end,
answer_file=answer_file,
max_new_token=args.max_new_token,
num_choices=args.num_choices,
num_gpus_per_model=args.num_gpus_per_model,
num_gpus_total=args.num_gpus_total,
max_gpu_memory=args.max_gpu_memory,
dtype=str_to_torch_dtype(args.dtype),
)

reorg_answer_file(answer_file)
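The chunking change in the `@@ -45,7 +47,7 @@` hunk above is easiest to see with concrete numbers. The sketch below shows how `run_eval` now splits questions across Ray workers after dropping the extra `// 2`; the question count and GPU counts are made-up values for illustration, not anything taken from the commit.

```python
# Illustrative numbers only; variable names mirror run_eval above.
questions = list(range(160))    # pretend the benchmark has 160 questions
num_gpus_total = 8              # hypothetical cluster size
num_gpus_per_model = 2          # GPUs needed by one model replica

num_replicas = num_gpus_total // num_gpus_per_model   # 4 parallel workers
chunk_size = len(questions) // num_replicas           # 40 questions per chunk

chunks = [
    questions[i : i + chunk_size]
    for i in range(0, len(questions), chunk_size)
]
print(len(chunks), [len(c) for c in chunks])  # -> 4 [40, 40, 40, 40]
# With the removed "// 2", chunk_size would have been 20, producing 8 smaller
# chunks queued onto the same 4 workers.
```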
11 changes: 11 additions & 0 deletions fastchat/model/model_adapter.py
@@ -152,6 +152,7 @@ def load_model(
device: str = "cuda",
num_gpus: int = 1,
max_gpu_memory: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
@@ -282,6 +283,9 @@ def load_model(
return model, tokenizer
kwargs["revision"] = revision

if dtype is not None: # Overwrite dtype if it is provided in the arguments.
kwargs["torch_dtype"] = dtype

# Load model
model, tokenizer = adapter.load_model(model_path, kwargs)

@@ -393,6 +397,13 @@ def add_model_args(parser):
type=str,
help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "float16", "bfloat16"],
help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
default=None,
)
parser.add_argument(
"--load-8bit", action="store_true", help="Use 8-bit quantization"
)
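For illustration, a minimal sketch of how the new `dtype` override flows into `load_model`; the call mirrors the signature shown above, and the model path and GPU settings are assumptions for the example, not part of the commit.

```python
import torch

from fastchat.model import load_model

# Request bfloat16 weights instead of the float16-on-GPU default.
# "lmsys/vicuna-7b-v1.5" is only an example model path.
model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",
    device="cuda",
    num_gpus=1,
    dtype=torch.bfloat16,  # forwarded to kwargs["torch_dtype"] when not None
    load_8bit=False,
    cpu_offloading=False,
)
```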
3 changes: 3 additions & 0 deletions fastchat/serve/cli.py
@@ -26,11 +26,13 @@
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch

from fastchat.model.model_adapter import add_model_args
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype


class SimpleChatIO(ChatIO):
@@ -208,6 +210,7 @@ def main(args):
args.device,
args.num_gpus,
args.max_gpu_memory,
str_to_torch_dtype(args.dtype),
args.load_8bit,
args.cpu_offloading,
args.conv_template,
2 changes: 2 additions & 0 deletions fastchat/serve/inference.py
@@ -291,6 +291,7 @@ def chat_loop(
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype],
load_8bit: bool,
cpu_offloading: bool,
conv_template: Optional[str],
@@ -312,6 +313,7 @@
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
24 changes: 22 additions & 2 deletions fastchat/serve/model_worker.py
@@ -34,6 +34,7 @@
)
import torch
import torch.nn.functional as F
from transformers import set_seed
import uvicorn

from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
@@ -46,7 +47,12 @@
)
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
from fastchat.utils import (
build_logger,
pretty_print_semaphore,
get_context_length,
str_to_torch_dtype,
)


worker_id = str(uuid.uuid4())[:8]
@@ -190,13 +196,15 @@ def __init__(
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
stream_interval: int = 2,
conv_template: str = None,
conv_template: Optional[str] = None,
embed_in_truncate: bool = False,
seed: Optional[int] = None,
**kwargs,
):
super().__init__(
@@ -215,6 +223,7 @@ def __init__(
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
@@ -227,6 +236,7 @@ def __init__(
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
self.stream_interval = stream_interval
self.embed_in_truncate = embed_in_truncate
self.seed = seed

if not no_register:
self.init_heart_beat()
@@ -235,6 +245,8 @@ def generate_stream_gate(self, params):
self.call_ct += 1

try:
if self.seed is not None:
set_seed(self.seed)
for output in self.generate_stream_func(
self.model,
self.tokenizer,
@@ -475,6 +487,12 @@ def create_model_worker():
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
parser.add_argument(
"--seed",
type=int,
default=None,
help="Overwrite the random seed for each generation.",
)
args = parser.parse_args()
logger.info(f"args: {args}")

@@ -508,13 +526,15 @@ def create_model_worker():
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
dtype=str_to_torch_dtype(args.dtype),
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
seed=args.seed,
)
return args, worker
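The new `--seed` flag boils down to calling `transformers.set_seed` (which seeds Python's, NumPy's, and PyTorch's RNGs) right before each generation, so repeated requests with the same seed sample the same tokens. The standalone sketch below only illustrates that effect with toy logits; it is not the worker's actual generation loop.

```python
from typing import Optional

import torch
from transformers import set_seed


def sample_token(logits: torch.Tensor, seed: Optional[int] = None) -> int:
    # Mirrors the worker's new behavior: reseed right before generating.
    if seed is not None:
        set_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))


fake_logits = torch.linspace(-2.0, 2.0, steps=8)  # deterministic toy logits
print(sample_token(fake_logits, seed=42))  # reproducible
print(sample_token(fake_logits, seed=42))  # same token as the line above
print(sample_token(fake_logits))           # no reseed: depends on current RNG state
```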

15 changes: 15 additions & 0 deletions fastchat/utils.py
@@ -302,3 +302,18 @@ def get_context_length(config):
if val is not None:
return int(rope_scaling_factor * val)
return 2048


def str_to_torch_dtype(dtype: str):
import torch

if dtype is None:
return None
elif dtype == "float32":
return torch.float32
elif dtype == "float16":
return torch.float16
elif dtype == "bfloat16":
return torch.bfloat16
else:
raise ValueError(f"Unrecognized dtype: {dtype}")
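
A quick usage sketch of the new helper; the assertions just restate the branches above (`None` passes through, known names map to `torch` dtypes, anything else raises `ValueError`).

```python
import torch

from fastchat.utils import str_to_torch_dtype

assert str_to_torch_dtype(None) is None              # no override requested
assert str_to_torch_dtype("float32") == torch.float32
assert str_to_torch_dtype("bfloat16") == torch.bfloat16
# str_to_torch_dtype("int8") raises ValueError("Unrecognized dtype: int8")
```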
