EXPERIMENTAL AND NOT PRODUCTION READY! Many rough edges.
Runs OPT-66b inference on a cluster composed of g4dn nodes (in my tests, 3 x g4dn.12xlarge, giving a total of 12 GPUs). You can also run it on 12 x g4dn.4xlarge.
python run_on_every_node.py download_model "s3://large-dl-models-mirror/models--anyscale--opt-66b-resharded/main/" "~/model"
python deepspeed_inference_actors.py --name "facebook/opt-66b" --checkpoint_path "~/model" --batch_size 1 --ds_inference --use_kernel --use_meta_tensor --num_worker_groups 1 --num_gpus_per_worker_group 12
This repository demonstrates how to use DeepSpeed Inference with Ray for scalable batch inference. The combination of these two tools allows for efficient generation of text with large language models, including models as large as OPT-66b.
DeepSpeed Inference utilizes automatic model parallelism to distribute the model across multiple GPUs. Ray handles the scheduling and orchestration of the workload.
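As a rough sanity check on why the model has to be split across GPUs, the sketch below estimates the weight memory each GPU must hold. It assumes roughly 66 billion parameters, fp16 weights (2 bytes per parameter), and an even split across 12 GPUs. None of these figures are taken from this repository's code, and the estimate ignores activations, the KV cache, and framework overhead.

```python
def weight_gib_per_gpu(num_params: float, bytes_per_param: int, num_gpus: int) -> float:
    """Approximate weight memory per GPU (in GiB) when weights are split evenly."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_gpus / 2**30


# Assumed values: ~66e9 parameters, fp16 (2 bytes each), 12 GPUs as in the setup above.
per_gpu = weight_gib_per_gpu(66e9, 2, 12)
single_gpu = weight_gib_per_gpu(66e9, 2, 1)
print(f"weights per GPU across 12 GPUs: {per_gpu:.1f} GiB")
print(f"weights on a single GPU:        {single_gpu:.1f} GiB")
```

Under these assumptions the weights alone come to well over 100 GiB in total, which is why a single GPU is not an option and the model is sharded across the whole worker group.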
There are three key parts to the code:
- deepspeed_inference_actors.py (the entrypoint) generates a sample Ray Dataset and uses ray.train.batch_predictor.BatchPredictor with a custom DeepSpeedPredictor. The BatchPredictor spawns num_worker_groups DeepSpeedPredictor actors, each receiving a share of the data.
- deepspeed_predictor.py contains the code for the DeepSpeedPredictor. Each DeepSpeedPredictor actor spawns num_gpus_per_worker_group worker actors (PredictionWorker), connected together via a torch.distributed backend, as required by DeepSpeed. Once initialized, the DeepSpeed model is ready for prediction.
- deepspeed_utils.py contains code based on a DeepSpeed example that is used by the PredictionWorker actors.
In other words, a DeepSpeedPredictor creates a worker group of PredictionWorker actors, which share a single model. A worker group is inelastic (if one worker fails, the entire group fails). This is similar to how Ray Train works (in fact, the logic can be implemented using Ray Train private APIs instead of PredictionWorker).
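To make the "connected together via a torch.distributed backend" step more concrete, here is a minimal, Ray-free sketch of the per-worker settings that such a group typically needs. RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are the environment variables read by torch.distributed's env:// initialization. The helper name make_worker_envs, the address, and the port are hypothetical, and this is not necessarily how deepspeed_predictor.py sets things up.

```python
from typing import Dict, List


def make_worker_envs(num_workers: int, master_addr: str, master_port: int) -> List[Dict[str, str]]:
    """Build one environment mapping per worker in a single worker group.

    Hypothetical helper for illustration only: every worker shares the same
    rendezvous address and world size, and differs only in its rank.
    """
    if num_workers < 1:
        raise ValueError("a worker group needs at least one worker")
    return [
        {
            "RANK": str(rank),                # unique id of this worker in the group
            "WORLD_SIZE": str(num_workers),   # total number of workers in the group
            "MASTER_ADDR": master_addr,       # address of the rank-0 worker
            "MASTER_PORT": str(master_port),  # a free port on the rank-0 worker
        }
        for rank in range(num_workers)
    ]


# Example: one worker group of 12 workers (address and port are placeholders).
envs = make_worker_envs(12, "10.0.0.1", 29500)
print(envs[0])
print(envs[-1])
```

Because all workers in the group share one WORLD_SIZE and one rendezvous point, losing any single worker breaks the collective, which matches the inelastic behavior described above.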
- If multiple worker groups are scheduled on one node, their workers will use the same CUDA devices, which leads to a crash. Therefore, it's best to either use single-GPU nodes, or make sure that the number of workers in a group divided by the number of nodes is equal to the number of GPUs on each node.
- Certain models obtained from Hugging Face hub will cause exceptions due to a bug in DeepSpeed. The solution is to reshard the checkpoints of those models to ensure that all layers are stored in contiguous files. The relevant code is included in huggingface_utils.py.
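The placement rule in the first caveat above can be written down as a small pre-flight check. This is only a literal encoding of the rule as stated (single-GPU nodes, or workers per group divided by the number of nodes equal to GPUs per node). The function name placement_is_safe is hypothetical and is not part of this repository.

```python
def placement_is_safe(workers_per_group: int, num_nodes: int, gpus_per_node: int) -> bool:
    """Return True if a configuration satisfies the placement rule stated above.

    Hypothetical helper: it only encodes the README's rule of thumb and does
    not inspect an actual cluster.
    """
    if min(workers_per_group, num_nodes, gpus_per_node) < 1:
        raise ValueError("all arguments must be positive integers")
    # Single-GPU nodes cannot host two workers, so groups never share a device.
    if gpus_per_node == 1:
        return True
    # Otherwise one group should fill every GPU on every node exactly.
    return workers_per_group == num_nodes * gpus_per_node


# The two setups mentioned at the top of this README:
print(placement_is_safe(12, 3, 4))   # 3 x g4dn.12xlarge, 4 GPUs per node
print(placement_is_safe(12, 12, 1))  # 12 x g4dn.4xlarge, 1 GPU per node
# A configuration the rule advises against: groups of 6 on 3 nodes with 4 GPUs each.
print(placement_is_safe(6, 3, 4))
```

Running a check like this before launching can turn a CUDA device clash at startup into an early, readable error.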
Key packages:
accelerate==0.17.1
deepspeed==0.8.3
ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
torch==2.0.0
transformers==4.27.2
All packages:
absl-py==1.4.0
accelerate==0.17.1
adal==1.2.7
aim==3.16.1
aim-ui==3.16.1
aimrecords==0.0.7
aimrocks==0.3.1
aiofiles==22.1.0
aiohttp==3.8.4
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
aiosqlite==0.18.0
ale-py==0.8.1
alembic==1.10.2
anyio==3.6.2
anyscale @ file:///home/ray/anyscale-0.0.0.dev0.tar.gz
anyscale-node-provider @ file:///home/ray/anyscale_node_provider-0.0.1.tar.gz
applicationinsights==0.11.10
argcomplete==1.12.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.2.0
autocfg==0.0.8
autogluon.common==0.7.0
autogluon.core==0.7.0
autograd==1.5
autopage==0.5.1
AutoROM==0.6.0
AutoROM.accept-rom-license==0.6.0
awscli==1.25.6
awscliv2==2.2.0
ax-platform==0.3.1
azure-cli-core==2.40.0
azure-cli-telemetry==1.0.8
azure-common==1.1.28
azure-core==1.26.3
azure-identity==1.10.0
azure-mgmt-compute==23.1.0
azure-mgmt-core==1.3.2
azure-mgmt-network==19.0.0
azure-mgmt-resource==20.0.0
Babel==2.12.1
backcall==0.2.0
backoff==1.10.0
backports.zoneinfo==0.2.1
base58==2.0.1
bayesian-optimization==1.2.0
bcrypt==4.0.1
beautifulsoup4==4.12.0
bitsandbytes==0.37.2
black==23.1.0
bleach==6.0.0
blessed==1.20.0
blobfile==2.0.1
boto3==1.26.95
botocore==1.29.95
botorch==0.8.3
cached-property==1.5.2
cachetools==5.3.0
catboost==1.1.1
certifi==2022.12.7
cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
chardet==5.1.0
charset-normalizer==3.1.0
chess==1.7.0
chex==0.1.6
click==8.1.3
cliff==4.2.0
cloudpickle==2.2.1
cma==2.7.0
cmaes==0.9.1
cmake==3.26.0
cmd2==2.4.3
colorama==0.4.6
coloredlogs==15.0.1
colorful==0.5.5
colorlog==6.7.0
comet-ml==3.31.9
comm==0.1.2
commonmark==0.9.1
conda==23.1.0
conda-content-trust @ file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work
conda-package-handling @ file:///croot/conda-package-handling_1666940373510/work
configobj==5.0.8
ConfigSpace==0.4.18
contourpy==1.0.7
coolname==2.2.0
cryptography @ file:///croot/cryptography_1673298753778/work
cycler==0.11.0
Cython==0.29.32
databricks-cli==0.17.5
DataProperty==0.55.0
datasets==2.10.1
debugpy==1.6.6
decorator==5.1.1
decord==0.6.0
deepspeed==0.8.3
defusedxml==0.7.1
Deprecated==1.2.13
diffusers @ git+https://github.com/huggingface/diffusers.git@7fe88613fa15d230d59482889c440c7befa17c25
dill==0.3.6
distlib==0.3.6
dm-tree==0.1.8
docker==6.0.1
docker-pycreds==0.4.0
docutils==0.16
dopamine-rl==4.0.5
dragonfly-opt==0.1.6
dulwich==0.21.3
einops==0.3.0
entrypoints==0.4
etils==1.1.1
evaluate==0.4.0
everett==3.1.0
exceptiongroup==1.1.1
executing==1.2.0
executor==23.2
expiringdict==1.2.2
fastapi==0.95.0
fasteners==0.18
fastjsonschema==2.16.3
filelock==3.10.0
FLAML==1.1.1
Flask==2.2.3
flatbuffers==2.0.7
flax==0.6.7
fonttools==4.39.2
fqdn==1.5.1
freezegun==1.1.0
frozenlist==1.3.3
fsspec==2023.3.0
ftfy==6.1.1
future==0.18.3
gast==0.4.0
gin-config==0.5.0
gitdb==4.0.10
GitPython==3.1.31
glfw==2.5.7
gluoncv==0.10.1.post0
google-api-core==2.11.0
google-api-python-client==1.7.8
google-auth==2.16.2
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-compute==1.10.1
google-cloud-core==2.3.2
google-cloud-resource-manager==1.9.0
google-cloud-secret-manager==2.16.0
google-cloud-storage==2.7.0
google-crc32c==1.5.0
google-oauth==1.0.1
google-pasta==0.2.0
google-resumable-media==2.4.1
googleapis-common-protos==1.58.0
gpustat==1.0.0
GPy==1.10.0
gpytorch==1.9.1
graphviz==0.8.4
greenlet==2.0.2
grpc-google-iam-v1==0.12.6
grpcio==1.51.3
grpcio-status==1.48.2
grpcio-tools==1.51.3
gunicorn==20.1.0
gym==0.26.2
gym-notices==0.0.8
Gymnasium==0.26.3
gymnasium-notices==0.0.1
h11==0.14.0
h5py==3.7.0
halo==0.0.31
HEBO==0.3.2
higher==0.2.1
hjson==3.1.0
hpbandster==0.7.4
httplib2==0.21.0
huggingface-hub==0.13.3
humanfriendly==10.0
humanize==4.6.0
hyperopt==0.2.5
idna==3.4
imageio==2.26.1
imageio-ffmpeg==0.4.5
importlib-metadata==6.1.0
importlib-resources==5.12.0
iniconfig==2.0.0
ipykernel==6.22.0
ipython==8.11.0
ipython-genutils==0.2.0
ipywidgets==8.0.4
isodate==0.6.1
isoduration==20.11.0
isort==5.12.0
itsdangerous==2.1.2
jax==0.4.6
jaxlib==0.4.6
jedi==0.18.2
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.2.0
json5==0.9.11
jsonlines==3.1.0
jsonpatch==1.32
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter-ydoc==0.2.3
jupyter_client==8.1.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_fileid==0.8.0
jupyter_server_terminals==0.4.4
jupyter_server_ydoc==0.6.1
jupyterlab==3.6.1
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.5
jupyterlab_server==2.20.0
kaggle-environments==1.7.11
keras==2.11.0
kiwisolver==1.4.4
knack==0.10.1
kubernetes==26.1.0
lazy_loader==0.1
libclang==15.0.6.1
libtorrent==2.0.7
lightgbm==3.3.5
lightgbm-ray==0.1.8
lightning-bolts==0.4.0
lightning-utilities==0.8.0
linear-operator==0.3.0
lit==16.0.0
lm-dataformat @ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
lm-eval==0.3.0
log-symbols==0.0.14
lxml==4.9.2
lz4==4.3.2
Mako==1.2.4
Markdown==3.4.1
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mbstrdecoder==1.1.2
mdurl==0.1.2
minigrid==2.1.1
mistune==2.0.5
mlagents-envs==0.28.0
mlflow==1.30.0
modin==0.18.1
monotonic==1.6
mosaicml==0.12.1
mpmath==1.3.0
msal==1.18.0b1
msal-extensions==1.0.0
msgpack==1.0.5
msrest==0.7.1
msrestazure==0.6.4
mujoco==2.2.0
mujoco-py==2.1.2.14
multidict==6.0.4
multipledispatch==0.6.0
multiprocess==0.70.14
mxnet==1.8.0.post0
mypy-extensions==1.0.0
nbclassic==0.5.3
nbclient==0.7.2
nbconvert==7.2.10
nbformat==5.8.0
nest-asyncio==1.5.6
netifaces==0.11.0
networkx==3.0
nevergrad==0.4.3.post7
ninja==1.11.1
nltk==3.8.1
notebook==6.5.3
notebook_shim==0.2.2
numexpr==2.8.4
numpy==1.23.5
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-ml-py==11.495.46
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauth2client==4.1.3
oauthlib==3.2.2
onnx==1.12.0
onnxruntime==1.14.1
open-spiel==1.2
openai==0.27.2
opencensus==0.11.2
opencensus-context==0.1.3
opencv-python==4.7.0.72
opentelemetry-api==1.1.0
opentelemetry-exporter-otlp==1.1.0
opentelemetry-exporter-otlp-proto-grpc==1.1.0
opentelemetry-exporter-otlp-proto-http==1.16.0
opentelemetry-proto==1.1.0
opentelemetry-sdk==1.1.0
opentelemetry-semantic-conventions==0.20b0
opt-einsum==3.3.0
optax==0.1.4
optuna==2.10.0
orbax==0.1.5
packaging==23.0
pandas==1.5.3
pandocfilters==1.5.0
paramiko==2.12.0
paramz==0.9.5
parso==0.8.3
pathspec==0.11.1
pathtools==0.1.2
pathvalidate==2.5.2
patsy==0.5.3
pbr==5.11.1
PettingZoo==1.22.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.4.0
pkginfo==1.9.6
pkgutil_resolve_name==1.3.10
platformdirs==3.1.1
plotly==5.13.1
pluggy @ file:///tmp/build/80754af9/pluggy_1648042571233/work
portalocker==2.7.0
prettytable==3.6.0
prometheus-client==0.13.1
prometheus-flask-exporter==0.22.3
promise==2.3
prompt-toolkit==3.0.38
property-manager==3.0
proto-plus==1.22.2
protobuf==3.20.3
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py-spy==0.3.14
py3nvml==0.2.7
pyaml==21.10.1
pyarrow==11.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.6.2
pycosat @ file:///croot/pycosat_1666805502580/work
pycountry==22.3.5
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodomex==3.17
pydantic==1.10.6
pyDeprecate==0.3.2
pygame==2.1.2
pyglet==1.5.15
Pygments==2.14.0
PyJWT==2.6.0
pymoo==0.5.0
pymunk==6.2.1
PyNaCl==1.5.0
PyOpenGL==3.1.6
pyOpenSSL==23.0.0
pyparsing==3.0.9
pyperclip==1.8.2
pypng==0.20220715.0
pyro-api==0.1.2
pyro-ppl==1.8.4
Pyro4==4.82
pyrsistent==0.19.3
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytablewriter==0.64.2
pytest==7.2.2
pytest-remotedata==0.3.2
python-dateutil==2.8.2
python-json-logger==2.0.7
pytorch-lightning==2.0.0
pytorch-ranger==0.1.1
pytz==2022.7.1
pytz-deprecation-shim==0.1.0.post0
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.0.2
querystring-parser==1.2.4
ray @ file:///home/ray/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
ray-lightning==0.3.0
recsim==0.2.4
redis==3.5.3
regex==2022.10.31
requests==2.28.2
requests-oauthlib==1.3.1
requests-toolbelt==0.10.1
responses==0.18.0
RestrictedPython==6.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==12.0.1
rouge-score==0.1.2
rsa==4.9
ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work
ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work
s3transfer==0.6.0
sacrebleu==1.5.0
scikit-image==0.20.0
scikit-learn==1.2.2
scikit-optimize==0.9.0
scipy==1.10.1
segment-analytics-python==2.2.2
semantic-version==2.10.0
Send2Trash==1.8.0
sentencepiece==0.1.96
sentry-sdk==1.17.0
serpent==1.41
setproctitle==1.3.2
shortuuid==1.0.1
sigopt==7.5.0
six==1.16.0
smart-open==6.3.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.4
spinners==0.0.24
SQLAlchemy==1.4.47
sqlitedict==2.1.0
sqlparse==0.4.3
stack-data==0.6.2
starlette==0.26.1
statsmodels==0.13.5
stevedore==5.0.0
SuperSuit==3.7.0
sympy==1.11.1
tabledata==1.3.1
tabulate==0.9.0
tblib==1.7.0
tcolorpy==0.1.2
tenacity==8.2.2
tensorboard==2.12.0
tensorboard-data-server==0.7.0
tensorboard-plugin-wit==1.8.1
tensorboardX==2.4.1
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.31.0
tensorflow-probability==0.19.0
tensorstore==0.1.33
termcolor==2.2.0
terminado==0.10.1
tf-slim==1.1.0
tf2onnx==1.13.0
threadpoolctl==3.1.0
tifffile==2023.3.15
tiktoken==0.1.2
timm==0.4.5
tinycss2==1.2.1
tinyscaler==1.2.5
tokenizers==0.13.2
tomli==2.0.1
toolz @ file:///croot/toolz_1667464077321/work
torch==2.0.0
torch-optimizer==0.3.0
torchaudio==2.0.1
torchmetrics==0.11.4
torchvision==0.15.1
tornado==6.2
tqdm==4.65.0
tqdm-multiprocess==0.0.11
traitlets==5.9.0
transformers==4.27.2
triton==2.0.0
tune-sklearn==0.4.4
typeguard==2.13.3
typepy==1.3.0
typer==0.6.1
typing_extensions==4.5.0
tzdata==2022.7
tzlocal==4.3
ujson==5.7.0
uri-template==1.2.0
uritemplate==3.0.1
urllib3==1.26.15
uvicorn==0.21.1
verboselogs==1.7
virtualenv==20.21.0
wandb==0.13.4
wcwidth==0.2.6
webcolors==1.12
webencodings==0.5.1
websocket-client==1.5.1
Werkzeug==2.2.3
widgetsnbextension==4.0.5
wrapt==1.15.0
wurlitzer==3.0.3
xgboost==1.7.4
xgboost-ray==0.1.15
xmltodict==0.13.0
xxhash==3.2.0
y-py==0.5.9
yacs==0.1.8
yarl==1.8.2
ypy-websocket==0.8.2
zipp==3.15.0
zoopt==0.4.1
zstandard==0.20.0