forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
36cacf0
commit 4be387d
Showing
19 changed files
with
454 additions
and
274 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
accelerate==0.15.0 | ||
bitsandbytes | ||
deepspeed==0.7.7 | ||
./transformers | ||
-e ./transformers | ||
|
||
# TODO: Dev only | ||
isort>=5.5.4 | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,23 @@ | ||
import pipelines | ||
from utils import benchmark_end_to_end, get_arg_parser, get_args, get_dummy_batch | ||
from typing import List, Optional | ||
|
||
from src.pipelines import get_pipeline_class | ||
from src.utils.arguments import parse_args | ||
from src.utils.benchmark import benchmark_end_to_end | ||
from src.utils.input import get_dummy_batch | ||
from src.utils.logging import configure_logging | ||
|
||
def main() -> None: | ||
# deepspeed.init_distributed("nccl") | ||
|
||
args = get_args(get_arg_parser()) | ||
def main(argv: Optional[List[str]] = None) -> None: | ||
args = parse_args(argv=argv) | ||
|
||
inputs = get_dummy_batch(args.batch_size, args.max_input_length) | ||
|
||
generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) | ||
generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False} | ||
|
||
pipeline_class = getattr(pipelines, args.pipeline_class) | ||
pipeline_class = get_pipeline_class(args.pipeline_class) | ||
benchmark_end_to_end(args, pipeline_class, inputs, generate_kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
configure_logging() | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,11 @@ | ||
from .ds_inference import DS_Inference_Pipeline | ||
from .hf import HF_CPU_Pipeline, HF_GPU_Pipeline | ||
from .pipeline import Pipeline | ||
def get_pipeline_class(name): | ||
if name == "HF_Pipeline": | ||
from src.pipelines.transformers import HF_Pipeline | ||
|
||
return HF_Pipeline | ||
elif name == "DS_Pipeline": | ||
from src.pipelines.ds import DS_Pipeline | ||
|
||
return DS_Pipeline | ||
else: | ||
raise NotImplementedError(f"Unsupported pipeline class: {name}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os | ||
from argparse import Namespace | ||
|
||
import deepspeed | ||
import torch | ||
|
||
from src.pipelines.pipeline import Pipeline | ||
from src.utils.arguments import check_unused | ||
|
||
|
||
class DS_Pipeline(Pipeline): | ||
def __init__(self, args: Namespace) -> None: | ||
check_unused(args, {"device": torch.device("cuda")}, enforce=True) | ||
# TODO: Works with other dtypes? | ||
check_unused(args, {"dtype": torch.float16}) | ||
super().__init__(args) | ||
|
||
self.model = deepspeed.init_inference( | ||
self.model, | ||
mp_size=int(os.getenv("WORLD_SIZE", "1")), | ||
# base_dir="./", | ||
dtype=args.dtype, | ||
replace_with_kernel_inject=args.inject_kernel, | ||
enable_cuda_graph=args.cuda_graph, | ||
) |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.