Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Example: DeepSpeed deferred init with opt-30b #2419

Merged
merged 20 commits into from
Jun 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions examples/large_models/deepspeed/Readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Loading large Huggingface models on Multiple GPUs

This document briefs on serving large HuggingFace (HF) models on multiple GPUs using deepspeed. We are using facebook/opt-30b in this example

To run this example we need to have deepspeed installed. This has been added to the requirement.txt which can be bundled during model packaging.


```bash
pip install deepspeed

```

### Step 1: Download model

```bash
python ../utils/Download_model.py --model_path model --model_name facebook/opt-30b --revision main
```

The script prints the path where the model is downloaded as below.

`opt/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546`

### Step 2: Generate mar or tgz file

```bash
torch-model-archiver --model-name opt --version 1.0 --handler custom_handler.py --extra-files ds-config.json -r requirements.txt --config-file opt/model-config.yaml --archive-format tgz
```

### Step 3: Add the tgz file to model store

```bash
mkdir model_store
mv opt.tar.gz model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --model-store model_store --models opt.tar.gz
```

### Step 5: Run inference

```bash
curl "http://localhost:8080/predictions/opt" -T sample_text.txt
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from ts.context import Context
from ts.handler_utils.distributed.deepspeed import get_ds_engine
Expand Down Expand Up @@ -36,15 +36,21 @@ def initialize(self, ctx: Context):
model_dir = ctx.system_properties.get("model_dir")
self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
model_name = ctx.model_yaml_config["handler"]["model_name"]
model_path = ctx.model_yaml_config["handler"]["model_path"]
seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
torch.manual_seed(seed)

self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left")
logger.info("Model %s loading tokenizer", ctx.model_name)

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch.float16
)
self.model.eval()
config = AutoConfig.from_pretrained(model_name)
with torch.device("meta"):
self.model = AutoModelForCausalLM.from_config(
config, torch_dtype=torch.float16
)
self.model = self.model.eval()

ds_engine = get_ds_engine(self.model, ctx)
self.model = ds_engine.module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
"dtype": "torch.float16",
"replace_with_kernel_inject": true,
"tensor_parallel": {
"tp_size": 2
"tp_size": 4
}
}
}
94 changes: 0 additions & 94 deletions examples/large_models/deepspeed/opt/Readme.md

This file was deleted.

7 changes: 5 additions & 2 deletions examples/large_models/deepspeed/opt/model-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ responseTimeout: 1200
parallelType: "tp"
deviceType: "gpu"
# example of user specified GPU deviceIds
deviceIds: [2,3] # seting CUDA_VISIBLE_DEVICES
deviceIds: [0,1,2,3] # seting CUDA_VISIBLE_DEVICES

torchrun:
nproc-per-node: 2
nproc-per-node: 4

# TorchServe Backend parameters
deepspeed:
config: ds-config.json
checkpoint: checkpoints.json

handler:
model_name: "facebook/opt-30b"
model_path: "/home/ubuntu/serve/examples/large_models/deepspeed/opt/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546"
max_length: 50
max_new_tokens: 10
manual_seed: 40
2 changes: 0 additions & 2 deletions examples/large_models/deepspeed/opt/requirements.txt

This file was deleted.

2 changes: 2 additions & 0 deletions examples/large_models/deepspeed/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
transformers>=4.30.1
deepspeed>=0.9.4
31 changes: 23 additions & 8 deletions ts/handler_utils/distributed/deepspeed.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import json
import logging
import os
from pathlib import Path

import deepspeed

from ts.context import Context


def create_checkpoints_json(model_path, checkpoints_json):
checkpoint_files = file_list = [
str(entry)
for entry in Path(model_path).rglob("*.[bp][it][n]")
if entry.is_file()
]
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
print(f"Creating deepspeed checkpoint file {checkpoints_json}")
json.dump(data, open(checkpoints_json, "w"))


def get_ds_engine(model, ctx: Context):
model_dir = ctx.system_properties.get("model_dir")
ds_config = None
checkpoint = None
ds_config, checkpoint = None, None
model_path = ctx.model_yaml_config["handler"]["model_path"]

if "deepspeed" in ctx.model_yaml_config:
# config: the deepspeed config json file path.
# deepspeed config parameters:
Expand All @@ -23,17 +37,18 @@ def get_ds_engine(model, ctx: Context):
f"{ctx.model_name} has no deepspeed config file {ds_config}"
)

if "checkpoint" in ctx.model_yaml_config:
if "checkpoint" in ctx.model_yaml_config["deepspeed"]:
checkpoint = os.path.join(
model_dir, ctx.model_yaml_config["deepspeed"]["checkpoint"]
)
if not os.path.exists(checkpoint):
raise ValueError(
f"{ctx.model_name} has no deepspeed checkpoint file {checkpoint}"
)
create_checkpoints_json(model_path, checkpoint)

logging.debug("Creating DeepSpeed engine")
ds_engine = deepspeed.init_inference(
model, config=ds_config, checkpoint=checkpoint
model,
config=ds_config,
base_dir=model_path,
checkpoint=checkpoint,
)
return ds_engine
else:
Expand Down