bigscience/T0 multi-gpu inference exits with return code -9 #16616
Comments
Please try again with the exact code from https://huggingface.co/docs/transformers/main/main_classes/deepspeed#custom-deepspeed-zero-inference as I cleaned it up a bit more - I have just re-tested with it and it works just fine on 2x RTX 3090 gpus. Note that it uses the smaller T0_3B model, though.
I'm using the master/main versions of transformers/deepspeed and pt-1.11.
OK, I managed to crash my system with the 11B version on 2 gpus. I need to figure out cgroup v2, as I moved to Ubuntu 21.10 and my v1 setup no longer works. Meanwhile I figured out how to run a shell that will not let any process started from it use more memory than I allow, and thus not kill the host:
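One way to do that on a cgroup v2 system is a memory-capped systemd scope - a sketch only, not necessarily the exact command used here, and it assumes systemd user-scope memory delegation is available:

# Start a shell whose whole process tree is capped at 48GB of RAM (pick your own limit);
# anything exceeding the cap gets OOM-killed inside the scope instead of taking down the host.
systemd-run --user --scope -p MemoryMax=48G -p MemorySwapMax=0 bash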
but since we have this huge checkpoint of 42GB, I don't have enough RAM to load it twice in 2 processes. We have just added sharded checkpoints, so we need to switch T0 to that. Meanwhile I'm trying to figure out how to get this to run with nvme offload. I will update once I have something running.
Thanks for your help! I've tried to run the example at the link, and now I get another error, related to Ninja--full traceback below. This is an error I have seen before when trying to run the script I provided in my initial post. The errors seemed to alternate between the return code -9 and this Ninja error, without changing anything in the code. If the example works for you, I can't figure out what's going wrong on my end. Ninja is installed in my environment, and I am going to set up a new environment and see if that has better results.
What's the output of this? For comparison, here is the output in my conda env:
Perhaps try the same on your end.
One of the deepspeed devs was able to reproduce your original error - it seems to be tied to one specific component. So until they figure it out, the quick fix is not to use it ;) and to use the alternative instead.
cc: @jeffra
I don't see any output when I give this command. Thanks for the pointer. I will say that I was able to launch T0 and get it working several times last week and early this week, so I'm not sure why the Ninja error is suddenly appearing.
No output means it can't find ninja in your PATH. There could be 2 issues: either ninja isn't installed into the active environment at all, or it is installed but its bin/ directory isn't on PATH.
Let's look at each case:
e.g. in my case:
So the binary is there. Typically conda pushes that path into PATH when the environment is activated.
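The checks in question are along these lines (generic commands, nothing project-specific):

which ninja                          # should resolve inside the active env, e.g. .../envs/<env>/bin/ninja
ninja --version                      # should print a version number
echo $PATH                           # confirm the env's bin/ directory is on PATH
pip install --force-reinstall ninja  # reinstall into the active environment if it's missing or broken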
I was able to finally get past the Ninja problem by force-installing ninja. I also made a new environment and installed all the necessary packages. Here's the information for the new environment:
For both my original and new environment, I can get T0_3B to work on the Custom DeepSpeed ZeRO Inference example. However, the Custom DeepSpeed ZeRO Inference with the T0 model still finishes with exit code -9 and now mentions
Something else to note is that I was able to successfully run T0 and get output last week, around March 31st. In that case, I had two processes running at the same time, sending the same example to both processes, and output would be generated. When I sent different examples to each process, it appeared that rank=0 would finish before rank=1, and the input at rank=1 would hang.
Glad to hear you figured that out. The traceback you pasted is from the launcher, not the actual program. There are 2 independent programs: the launcher starts your actual program, and its traceback only shows that it detected that your program has failed. But do you have the traceback from your program itself?
I understand the symptom. It means that the gpu synchronisation code in generate didn't kick in. So we need to figure out why the sync wasn't enabled. The sync is enabled here:
which tells me that deepspeed zero-3 probably wasn't detected as enabled in your run.
Could you insert:
before:
and see if it reports: "Deepspeed 3 is enabled: True"
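The exact snippet didn't survive the quote; a minimal sketch of that check, using transformers' is_deepspeed_zero3_enabled helper (placing it just before the generate call is an assumption), would be:

from transformers.deepspeed import is_deepspeed_zero3_enabled

# Should print True on every rank once HfDeepSpeedConfig with ZeRO stage 3 is live;
# if it prints False, generate() won't run its gpu-sync loop and ranks can hang or exit out of step.
print("Deepspeed 3 is enabled:", is_deepspeed_zero3_enabled())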
Further, for now please switch to this branch of Deepspeed. You can install it directly like so:
Please install this branch and then try again. Note: this branch is a bit slow at the moment as prefetch is currently not working, but it'll get fixed once the Deepspeed team is back to work, so it'll be faster once it's enabled again.
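The branch name didn't make it into this quote; generically, installing a Deepspeed branch straight from git looks like this (the placeholder is mine):

pip install git+https://github.com/microsoft/DeepSpeed.git@<branch-name>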
Here is the nvme offload version that I tested with. Works great even with 1x or 2x tiny gpu - I didn't see more than 3GB used on each, but it's slow of course.
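The script itself isn't reproduced here, but the consolidated example later in this thread exposes the same choice through flags, e.g. (the offload path just needs to be a writable directory on the nvme drive):

python -m torch.distributed.run --nproc_per_node=2 T0_inference.py --offload nvme --nvme-offload-path /path/to/nvme_offload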
That is the full output when I run the program to use the T0 model. There are a few additional lines above what I posted, but there is no additional traceback info. I'll post the full output here (this is before I executed
I've switched over to this branch.
When running the zero inference example with T0_3B, the program outputs "Deepspeed 3 is enabled: True" (twice) and successfully returns predictions for the two examples. When I try to use the same zero inference example with T0, I get the same error as above (still without any extra traceback info). It does not output "Deepspeed 3 is enabled: True", so it must be exiting the program before it reaches that line.
I tried to run this example and got another error when running it as
Now that I've switched to a new branch of Deepspeed, should I try that nvme example again? If so, here's the output I get when running it with the new branch:
I saw your post in DeepSpeed #1037 saying that I might need to install libaio.
I'm going to check if this is just a permissions issue--hopefully that will fix it.
Yes, you need to install libaio system-wide. If for any reason you have an issue with installing libaio system-wide, here is how to install it via conda if you use the latter: deepspeedai/DeepSpeed#1890. So let's try the nvme solution once you have installed it. Wrt the failure to start with T0, I wonder if your kernel kills the program because it tries to use 4x the cpu memory (over the 3B model that works), and on 2 gpus that's a huge amount of additional memory (64GB more). Perhaps something gets logged in the system log (dmesg) when that happens. How much cpu memory do you have on this host? Perhaps try the low_cpu_mem approach:
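As a sketch, assuming this refers to transformers' low_cpu_mem_usage flag on from_pretrained:

from transformers import AutoModelForSeq2SeqLM

# Loads the weights without materializing a second full copy of the state dict in CPU RAM,
# which roughly halves the peak memory during loading. It may not combine with the ZeRO-3
# zero.Init loading path used elsewhere in this thread.
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0", low_cpu_mem_usage=True)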
But Deepspeed should really have a parameter that defines how much CPU memory can be used.
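For reference, the usual ways to install the libaio prerequisite (package names can vary by distro):

sudo apt install libaio-dev          # Debian/Ubuntu, system-wide
conda install -c conda-forge libaio  # per-environment alternative (see deepspeedai/DeepSpeed#1890)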
I was able to install libaio. Now I'm getting a permission error related to nvme:
Here's the output of
I tried the low memory approach, and I got a message saying that
Apologies if it wasn't obvious - you were meant to edit the path to some path on your own filesystem; it just happened to be a path from my setup in the example.
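Concretely, that means something like the following (the directory name is just an example), and then pointing the ds_config's nvme_path - or the --nvme-offload-path flag of the consolidated script below - at it:

# pick a directory on the nvme drive that your user owns and can write to
mkdir -p /path/to/nvme_offload
# if the mount point is root-owned, hand the directory over to your user
sudo chown "$USER": /path/to/nvme_offload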
So 128GB of CPU RAM. When dealing with huge models it always helps to have some swap memory, which extends your effective CPU memory.
So yes, as expected your system kills the process, as it consumes too much CPU memory.
Ah, yes, sorry, that is still a work in progress; I will need to work on making that path play well with Deepspeed. Hmm, staggered loading should overcome this issue as well - basically having the 2nd instance of the script insert a delay before loading the model.
Actually, the staggering most likely won't work, since deepspeed's loading is synchronised across the processes.
Here is a recipe to add swap memory - of course, edit the path and the desired amount of GBs:
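A standard swap-file recipe (path and size are placeholders):

sudo fallocate -l 64G /path/to/swapfile   # reserve the space; adjust size and path
sudo chmod 600 /path/to/swapfile
sudo mkswap /path/to/swapfile
sudo swapon /path/to/swapfile
swapon --show                             # verify the swap space is active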
now use "t0-sharded" as a model name (at some point we will have a sharded version on the hub) you can shard it into even smaller chunks, say of 5GB:
I'd say do the latter for this experiment. And of course, point the script at the local sharded copy.
With these 2 fixes we will still need a good amount of CPU memory, but please let me know if this unblocks you. Another way: with Deepspeed nvme + cpu offload, 1 gpu should be enough! You only need to be able to load the single largest layer, and if you don't care for parallel input processing you're not gaining anything from 2 gpus anyway when using nvme offload (I think - I haven't measured, so I can be wrong). And I still want to try to work out the remaining pieces.
@stas00, thank you so much for your help! I'm answering for @gportill since we were working on this issue together. Summary of what worked:
Full working example:
This example was modified from #15399 (comment) and assumes all of the "summary of what worked" steps were taken.
#!/usr/bin/env python
# This script demonstrates how to use Deepspeed ZeRO in an inference mode when one can't fit a model
# into a single GPU
#
# 1. Use 1 GPU with CPU offload
# 2. Or use multiple GPUs instead
#
# First you need to install deepspeed: pip install deepspeed
#
# Here we use a 3B "bigscience/T0_3B" model which needs about 15GB GPU RAM - so 1 largish or 2
# small GPUs can handle it. or 1 small GPU and a lot of CPU memory.
#
# To use a larger model like "bigscience/T0" which needs about 50GB, unless you have an 80GB GPU -
# you will need 2-4 gpus. And then you can adapt the script to handle more gpus if you want to
# process multiple inputs at once.
#
# The provided deepspeed config also activates CPU memory offloading, so chances are that if you
# have a lot of available CPU memory and you don't mind a slowdown you should be able to load a
# model that doesn't normally fit into a single GPU. If you have enough GPU memory the program will
# run faster if you don't want offload to CPU - so disable that section then.
#
# To deploy on 1 gpu:
#
# deepspeed --num_gpus 1 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=1 t0.py
#
# To deploy on 2 gpus:
#
# deepspeed --num_gpus 2 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=2 t0.py
# Imports
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
# To avoid warnings about parallelism in tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
from argparse import ArgumentParser
#################
# DeepSpeed Config
#################
def generate_ds_config(args):
"""
ds_config notes
- enable bf16 if you use Ampere or higher GPU - this will run in mixed precision and will be
faster.
- for older GPUs you can enable fp16, but it'll only work for non-bf16 pretrained models - e.g.
all official t5 models are bf16-pretrained
- set offload_param.device to "none" or completely remove the `offload_param` section if you don't
  want CPU offload
- if using `offload_param` you can manually finetune stage3_param_persistence_threshold to control
  which params should remain on gpus - the larger the value the smaller the offload size
For indepth info on Deepspeed config see
https://huggingface.co/docs/transformers/main/main_classes/deepspeed
keeping the same format as json for consistency, except it uses lower case for true/false
fmt: off
"""
config = AutoConfig.from_pretrained(args.model_name)
world_size = int(os.getenv("WORLD_SIZE", "1"))
model_hidden_size = config.d_model
# batch size has to be divisible by world_size, but can be bigger than world_size
train_batch_size = args.batch_size * world_size
config = {
"fp16": {
"enabled": False
},
"bf16": {
"enabled": False
},
"zero_optimization": {
"stage": 3,
"offload_param": {
"device": args.offload,
"nvme_path": args.nvme_offload_path,
"pin_memory": True,
"buffer_count": 6,
"buffer_size": 1e8,
"max_in_cpu": 1e9
},
"aio": {
"block_size": 262144,
"queue_depth": 32,
"thread_count": 1,
"single_submit": False,
"overlap_events": True
},
"overlap_comm": True,
"contiguous_gradients": True,
"reduce_bucket_size": model_hidden_size * model_hidden_size,
"stage3_prefetch_bucket_size": 0.1 * model_hidden_size * model_hidden_size,
"stage3_max_live_parameters": 1e8,
"stage3_max_reuse_distance": 1e8,
"stage3_param_persistence_threshold": 10 * model_hidden_size
},
"steps_per_print": 2000,
"train_batch_size": train_batch_size,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": False
}
return config
#################
# Helper Methods
#################
def parse_args():
"""Parse program options"""
parser = ArgumentParser()
parser.add_argument("--model-name", default="bigscience/T0", help="Name of model to load.")
parser.add_argument("--offload", choices=["nvme", "cpu", "none"], default="none",
help="DeepSpeed optimization offload choices for ZeRO stage 3.")
parser.add_argument("--nvme-offload-path", default="/tmp/nvme-offload",
help="Path for NVME offload. Ensure path exists with correct write permissions.")
parser.add_argument("--batch-size", default=1, help="Effective batch size is batch-size * # GPUs")
return parser.parse_args()
#################
# Main
#################
# Distributed setup
local_rank = int(os.getenv("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
args = parse_args()
ds_config = generate_ds_config(args)
# fmt: on
# next line instructs transformers to partition the model directly over multiple gpus using
# deepspeed.zero.Init when model's `from_pretrained` method is called.
#
# **it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name)**
#
# otherwise the model will first be loaded normally and only partitioned at forward time which is
# less efficient and when there is little CPU RAM may fail
dschf = HfDeepSpeedConfig(ds_config) # keep this object alive
# Special version of T0
revision = None
if args.model_name in ["bigscience/T0", "bigscience/T0pp"]:
revision = "sharded"
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, revision=revision)
# initialise Deepspeed ZeRO and store only the engine object
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval() # inference
# Deepspeed ZeRO can process unrelated inputs on each GPU. So for 2 gpus you process 2 inputs at once.
# If you use more GPUs adjust for more.
# And of course if you have just one input to process you then need to pass the same string to both gpus
# If you use only one GPU, then you will have only rank 0.
rank = torch.distributed.get_rank()
if rank == 0:
text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
# synced_gpus (bool, optional, defaults to False) —
# Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs —
# Additional model specific keyword arguments will be forwarded to the forward function of the model.
# If model is an encoder-decoder model the kwargs should include encoder_outputs.
with torch.no_grad():
outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n in={text_in}\n out={text_out}\n") And the following code to run: export CUDA_LAUNCH_BLOCKING=0
export OMP_NUM_THREADS=1
python -m torch.distributed.run --nproc_per_node=2 T0_inference.py
That's a really neat summary and code parametrization, @AADeLucia - great work! Just to add that with the sharded model it's now possible to infer T0 (42GB) and other similar models in fp32 using just 2x 24GB gpus, w/ deepspeed and w/o any offload. But if you have smaller GPUs, just one GPU, or larger models, then the above script allows you to offload to cpu RAM if you have lots of it, and if not, to an NVMe device - each making the performance progressively slower. And once:
Environment info
transformers version: 4.17.0.dev0
Who can help
Library:
Information
Model I am using: T0
The problem arises when using:
The tasks I am working on are:
To reproduce
I want to load T0 across two 24GB GPUs with DeepSpeed in order to run inference. I followed the example code given here in issue #15399.
When running the code below, after the model says "finished initializing model with 11.14B parameters", it quits without outputting a model response. It does not give an error or traceback, just a return code of -9.
Here is the code. Run with:
deepspeed --num_gpus 2 <script.py>
Expected behavior
T0 should load across 2 GPUs, generate an answer, and then quit.