Skip to content

Commit

Permalink
Refactor the runner
Browse files Browse the repository at this point in the history
  • Loading branch information
aoyulong committed Oct 14, 2024
1 parent d5ea05a commit 782a26f
Show file tree
Hide file tree
Showing 14 changed files with 693 additions and 566 deletions.
22 changes: 0 additions & 22 deletions examples/aquila/conf/config_infer.yaml

This file was deleted.

15 changes: 6 additions & 9 deletions examples/aquila/conf/inference/inference_aquila_7b.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
engine:
model: BAAI/Aquila-7B/
tokenizer: BAAI/Aquila-7B/
llm:
model: xxxx
trust_remote_code: true
tensor_parallel_size: 1
pipeline_parallel_size: 1
gpu_memory_utilization: 0.6
dtype: bfloat16
seed: 1234

data:
generate:
prompts: [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# prompts_path: null
top_p: 0.95
top_k: 100
max_tokens: 7
temperature: 0.9
sampling:
top_p: 0.95
temperature: 0.8
8 changes: 4 additions & 4 deletions flagscale/auto_tuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from omegaconf import DictConfig, OmegaConf

from flagscale.launcher.job_status import JobStatus
from flagscale.launcher.runner import SSHRunner
from flagscale.runner.runner_base import JobStatus
from flagscale.runner.runner_train import SSHTrainRunner

from .generate import Generator
from .platform import set_jiuding_platform_args
Expand Down Expand Up @@ -160,7 +160,7 @@ def tune(self):
raise ValueError(f"No strategy can run.")
best_task = self.generator.gen_best_task(best_strategy, self.orig_config)
best_task.action = "run"
runner = SSHRunner(best_task)
runner = SSHTrainRunner(best_task)
runner.run(monitor=True, interval=60)

def need_stop(self):
Expand Down Expand Up @@ -213,7 +213,7 @@ def run(self, task=None):
# Instantiate a runner and run the task
if task is None:
task = self.cur_task
self.runner = SSHRunner(task)
self.runner = SSHTrainRunner(task)
self.runner.run()
# set start time
self.task_start_time = time.time()
Expand Down
2 changes: 1 addition & 1 deletion flagscale/auto_tuner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import subprocess
from types import SimpleNamespace

from flagscale.launcher.runner import parse_hostfile
from flagscale.runner.runner import parse_hostfile


def divisible(x, y):
Expand Down
179 changes: 51 additions & 128 deletions flagscale/inference/inference_aquila.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,63 @@
import os
import yaml
import argparse
from typing import List, Union

import torch

from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams

from arguments import parse_args


def process_requests(prompts: List[str],
engine: LLMEngine,
sampling_params: SamplingParams):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while prompts:
prompt = prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1
from omegaconf import OmegaConf, ListConfig
from vllm import LLM, SamplingParams


def get_config():
parser = argparse.ArgumentParser()
parser.add_argument("--config-path", type=str, required=True, help="Path to the configuration YAML file")
args = parser.parse_args()

config_path = args.config_path
# Open the YAML file and convert it into a dictionary
with open(config_path, 'r') as file:
config_dict = yaml.safe_load(file)

# Convert the dictionary into a DictConfig
config = OmegaConf.create(config_dict)
return config


def get_prompts(prompts):
print(prompts, type(prompts))
if isinstance(prompts, str) and os.path.isfile(prompts):
with open(prompts, 'r') as file:
return [line.strip() for line in file.readlines()]
elif isinstance(prompts, (list, ListConfig)):
return prompts
else:
raise ValueError("Prompts should be either a list of strings or a path to a file containing a list of strings.")

outputs: List[Union[RequestOutput]] = []
while engine.has_unfinished_requests():
step_outputs = engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)

outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
def inference():
# Get the configuration.
config = get_config()

# Get the prompts.
prompts = get_prompts(config.generate.prompts)

def inference(args: argparse.Namespace, prompts: List[str]):
"""Initialize the LLMEngine"""
engine_args = EngineArgs.from_cli_args(args)
llm_engine = LLMEngine.from_engine_args(engine_args)
# Create a sampling params object.
sampling_args = config.get("sampling", {})
sampling_params = SamplingParams(**sampling_args)

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
llm_engine.tokenizer.tokenizer = tokenizer
# Create an LLM.
llm_args = config.get("llm", {})
model = llm_args.pop("model", None)
assert model is not None
llm = LLM(model, **llm_args)

"""Initialize the SamplingParams"""
sampling_params = SamplingParams(
n=args.n,
best_of=args.best_of,
frequency_penalty=args.frequency_penalty,
repetition_penalty=args.repetition_penalty,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
min_p=args.min_p,
seed=args.seed,
use_beam_search=args.use_beam_search,
length_penalty=args.length_penalty,
early_stopping=args.early_stopping,
stop=args.stop,
stop_token_ids=args.stop_token_ids,
include_stop_str_in_output=args.include_stop_str_in_output,
ignore_eos=args.ignore_eos,
max_tokens=args.max_tokens,
min_tokens=args.min_tokens,
logprobs=args.logprobs,
prompt_logprobs=args.prompt_logprobs,
detokenize=args.detokenize,
skip_special_tokens=args.skip_special_tokens,
spaces_between_special_tokens=args.spaces_between_special_tokens,
# logits_processors=,
# truncate_prompt_tokens=,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)

outputs = process_requests(prompts, llm_engine, sampling_params)
# Print the outputs.
for output in outputs:
print("\n")
print("="*50)
print("=> RequestOutput:", output)
token_ids = output.outputs[0].token_ids
print("=> generated text:", tokenizer.decode(token_ids))


def generate(args: argparse.Namespace, prompts: List[str]):

model = LlamaForCausalLM.from_pretrained(
args.model,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True
).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)

for prompt in prompts:
print("\n")
print("="*50)
print("=> prompt:", prompt)
tokens = tokenizer.encode_plus(prompt)["input_ids"]
tokens = torch.tensor(tokens)[None,].to(model.device)
input_length = len(tokens[0])
generation_config = GenerationConfig(
do_sample=True,
eos_token_id=tokenizer.convert_tokens_to_ids('<|extra_204|>'),
pad_token_id=tokenizer.convert_tokens_to_ids('<|endoftext|>'),
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)
out = model.generate(
tokens,
generation_config,
return_dict_in_generate=True,
output_scores=True,
)
out_ids = out["sequences"][0][input_length:].cpu().numpy()
out_text = tokenizer.decode(out_ids.tolist())
print("=> generated text:", out_text)
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
args = parse_args()

prompts = []
if args.prompts_path is not None:
with open(args.prompts_path, "r") as f:
while True:
prompt = f.readline()
if not prompt:
break
prompts.append(prompt[:-1]) # remove the last '\n' of prompt
elif len(args.prompts) > 1:
prompts = args.prompts
else:
raise ValueError("Pleace set right prompts_path or prompts data.")

"""
vllm inference
"""
inference(args, prompts)

"""
transformers inference
"""
# generate(args, prompts)
# Run the inference
inference()
7 changes: 0 additions & 7 deletions flagscale/launcher/job_status.py

This file was deleted.

File renamed without changes.
22 changes: 22 additions & 0 deletions flagscale/runner/runner_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from omegaconf import DictConfig
from enum import Enum


class JobStatus(Enum):
RUNNING = "Running"
TRANSITIONAL = "Transitional (Stopping or Starting)"
COMPLETED_OR_IDLE = "Completed or Not Started"


class RunnerBase(ABC):
def __init__(self, config: DictConfig):
self.config = config

@abstractmethod
def run(self, *args, **kwargs):
raise NotImplementedError

def stop(self, *args, **kwargs):
"""Optional method to override."""
pass
Loading

0 comments on commit 782a26f

Please sign in to comment.