[deepspeed] bigscience/T0* multi-gpu inference with ZeRO #15399
My apologies, it looks like I wrote wrong instructions for the non-HF Trainer case here: https://huggingface.co/docs/transformers/master/main_classes/deepspeed#nontrainer-deepspeed-integration - is that where you found this code, or was it somewhere else? I'm asking so that we can ensure it's fixed everywhere. It should be just |
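For reference, a hedged note: the import that appears in the working code later in this thread (and so is presumably the corrected one) is:

```python
from transformers.deepspeed import HfDeepSpeedConfig
```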
ok, so indeed the import was wrong. I will fix the doc at https://huggingface.co/docs/transformers/master/main_classes/deepspeed#nontrainer-deepspeed-integration => #15400 But where did you get the rest of the code from? It can't possibly work. You may want to look into using / adapting https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-generation/run_generation.py - I see it doesn't currently support t5 models. @VictorSanh, you were working on multi-gpu generation with t0 models - what's the latest incarnation of the code that you were using, if you don't mind sharing? I think it was with Deepspeed-Inference, right? Thanks. Perhaps |
That works! Now running into a different issue, figuring out the default config arguments to change.
That was the only place I found that line.
Which part of the code can't work? The T0pp/T0_3B is just from the model card: https://huggingface.co/bigscience/T0pp

Updated code:

"""
Example code to load a PyTorch model across GPUs
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import pandas as pd
import torch
import pdb
import os

seed = 42
torch.manual_seed(seed)

if __name__ == "__main__":
    # must run before instantiating the model
    # ds_config is deepspeed config object or path to the file
    ds_config = "ds_config_zero3_gpu.json"
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

    model_name = "bigscience/T0_3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    engine = deepspeed.initialize(model=model, config_params=ds_config)

    inputs = tokenizer.encode(
        "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy",
        return_tensors="pt")
    outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0]))

I moved the config file outside because I was getting weird errors:

{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 1,
"stage3_prefetch_bucket_size": 1,
"stage3_param_persistence_threshold": 1,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},
"gradient_accumulation_steps": 1,
"gradient_clipping": 0,
"steps_per_print": 2000,
"train_batch_size": 1,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
} New error: self._configure_train_batch_size()
File "/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/lib/python3.8/site-packages/deepspeed/runtime/config.py", line 1050, in _configure_train_batch_size
self._batch_assertion()
File "/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/lib/python3.8/site-packages/deepspeed/runtime/config.py", line 997, in _batch_assertion
assert train_batch == micro_batch * grad_acc * self.world_size, (
AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size1 != 1 * 1 * 2
[2022-01-28 22:06:24,208] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 643653

Now just playing with the arguments... I'm not even training, I just want to run inference. |
Looking in the deepspeed config code... every time I change the batch-related parameters I get:

self._batch_assertion()
File "/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/lib/python3.8/site-packages/deepspeed/runtime/config.py", line 997, in _batch_assertion
assert train_batch == micro_batch * grad_acc * self.world_size, (
AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size2 != 1 * 1 * 1 |
Inference is a relatively new thing; I think until now most of the work was done on training, so please bear with us. Lots of tech is being developed as we speak and it's being polished to be super easy and fast. Until then, let's focus on something that works now. So this works with gpt2:
run:
With a small adjustment it works with t5:
run:
But if I switch the t5 model in the example above to "bigscience/T0_3B", it generates gibberish under Deepspeed but works fine w/o Deepspeed.
This is very puzzling. |
OK, I figured out the culprit - the model breaks when run under fp16! Like many other bf16-pretrained models - most t5 models have this issue. Here are 2 possible solutions:
(and as of this moment deepspeed@master is needed to use bf16 - they will make a new release any day now)
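To make the two fixes concrete, here is a hedged sketch of the relevant config fragment (expressed in Python; the same keys appear in the working config posted further down this thread):

```python
# option 1: switch from fp16 to bf16 (needs a bf16-capable GPU and, at this time, deepspeed@master)
ds_config["fp16"] = {"enabled": False}
ds_config["bf16"] = {"enabled": True}

# option 2: disable fp16 entirely and run in fp32 (works everywhere, uses more memory)
ds_config["fp16"] = {"enabled": False}
```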
With fp16 disabled you're now running in fp32 (more memory). So this works with either of the 2 fixes from above:
run:
Note that in the code above I pass ... The ds config needs more work to become efficient if you plan to use this in production or anywhere you care about speed. Since you're not using the HF integration you will have to put the right numbers together yourself; hint: transformers/src/transformers/deepspeed.py Lines 261 to 264 in 16d4acb
If you don't care to shave a few %s off, then leave the above as is and it'll use the Deepspeed untuned defaults. Please let me know if you're able to see it working for yourself. If you have any questions please ask. If all is satisfactory you may close this Issue. As I said in the previous comment the whole inference experience should get much much better really soon now. |
It's working! Thank you SO MUCH!!! I did have to use all the space-saving tips (bf16 and changing the defaults) because I want the entire model on GPU without off-loading parameters to CPU. Here is the completed code:

"""
Example code to load a PyTorch model across GPUs
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import torch
import pdb
import os

seed = 42
torch.manual_seed(seed)

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
model_hidden_size = 4096  # this is hard-coded to T0pp

ds_config = {
    "fp16": {
        "enabled": False,
    },
    "bf16": {
        "enabled": True,
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "steps_per_print": 2000,
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

if __name__ == "__main__":
    # must run before instantiating the model
    # ds_config is deepspeed config object or path to the file
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

    model_name = "bigscience/T0pp"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.eval()
    engine = deepspeed.initialize(model=model, config_params=ds_config, optimizer=None, lr_scheduler=None)

    text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
    inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)
    outputs = model.generate(inputs)
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
So glad it worked for you, Alexandra. But you can do better. Currently, with this code each gpu processes the same input, i.e. duplicated effort, but you can do parallel processing at 0 extra cost. You just need to change the end to something like:
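A minimal sketch of what that change might look like - this is a reconstruction, not the exact snippet originally posted, and it assumes the same variable names as the completed code above:

```python
import torch

rank = torch.distributed.get_rank()
# give each rank its own prompt instead of duplicating the same input on every gpu
if rank == 0:
    text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
    text = "Is this review positive or negative? Review: this is the worst restaurant I have ever been to"

inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)
# note: as discussed further down in this thread, generate() needs synced_gpus=True
# once the ranks receive different inputs, otherwise the sharded ensemble can hang
outputs = model.generate(inputs)
print(f"rank{rank}: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
```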
(the code is untested) You can of course do that for more than 2 gpus, and each gpu will handle its own unique input. And of course, you can do batches too if you have enough memory left. Beware: prints from multiple processes tend to interleave - you can use the following hack to overcome this issue: ... This Issue would be a good example for the DIY Deepspeed integration inference docs. |
Also please add the distributed init, so that the logging knows to not repeat the same logs for more than 1 gpu:
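A hedged sketch of the distributed init being referred to - placed near the top of the script, before the model is loaded:

```python
import deepspeed

# initialize the distributed backend early so logging can be deduplicated across ranks
deepspeed.init_distributed()
```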
|
OK, here is a much improved program which also integrates some enhancements from @VictorSanh's work. This script can now handle both cpu offload and/or multiple gpus, e.g. I can process "bigscience/T0_3B" on an 8GB GPU no problem. I added a bunch of notes before and throughout the code - please let me know if anything is unclear or missing:
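The script itself didn't survive in this copy of the thread, so here is a condensed sketch reconstructed from the pieces discussed above and below (ZeRO-3 without the HF Trainer, bf16, optional cpu offload, per-rank inputs, synced generation). Treat it as an approximation of the posted script, not the original; the file name, the second prompt, and the hidden-size values are assumptions:

```python
#!/usr/bin/env python
# launch with e.g.: deepspeed --num_gpus 2 zero_inference_sketch.py  (hypothetical file name)
import os

import deepspeed
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

model_name = "bigscience/T0_3B"
model_hidden_size = 2048  # d_model of T0_3B; T0pp uses 4096 - check config.json for other models

ds_config = {
    "fp16": {"enabled": False},
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size,
        # uncomment to enable cpu offload, e.g. to fit T0_3B on a small gpu:
        # "offload_param": {"device": "cpu", "pin_memory": True},
    },
    "steps_per_print": 2000,
    "train_batch_size": 1 * world_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False,
}

# must be created *before* from_pretrained so that ZeRO-3 sharded loading kicks in
dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

deepspeed.init_distributed()

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()

rank = torch.distributed.get_rank()
# each rank processes its own input, i.e. parallel processing at no extra cost
text_in = (
    "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
    if rank == 0
    else "Is this review positive or negative? Review: this is the worst restaurant I have ever been to"
)

inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    # synced_gpus keeps all ranks stepping together even if one finishes generating early
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")
```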
Let's take it for a run:
|
I tested the provided script with ... and got:

[2022-01-30 20:09:26,784] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 861020
[2022-01-30 20:09:26,784] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 861021
[2022-01-30 20:09:26,784] [ERROR] [launch.py:184:sigkill_handler] ['/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/bin/python', '-u', 'hf_zero_example.py', '--local_rank=1'] exits with return code = -9

I tried with only using ... And when I change ... I get:

rank0:
in=Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy
out=Positive
[2022-01-30 18:42:22,779] [INFO] [launch.py:210:main] Process 775166 exits successfully.

Sun Jan 30 19:34:16 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.86 Driver Version: 470.86 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 Off | N/A |
| 0% 56C P8 24W / 350W | 70MiB / 24260MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... Off | 00000000:21:00.0 Off | N/A |
| 49% 65C P2 155W / 350W | 14083MiB / 24268MiB | 100% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1491 G /usr/lib/xorg/Xorg 56MiB |
| 0 N/A N/A 1696 G /usr/bin/gnome-shell 9MiB |
| 1 N/A N/A 1491 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 775167 C ..._cersi_tobacco/bin/python 14033MiB |
+-----------------------------------------------------------------------------+

And then with ...:

rank = torch.distributed.get_rank()
text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
if rank == 0:
    print(f"rank{rank}:\n in={text_in}\n out={text_out}")

rank0:
in=Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy
out=Positive
[2022-01-30 20:19:47,097] [INFO] [launch.py:210:main] Process 861553 exits successfully.
[2022-01-30 20:19:48,099] [INFO] [launch.py:210:main] Process 861552 exits successfully.

Also,
Why does T0 need 50GB? The model plus the vocab/config files is ~42GB according to the card. I have two 24GB GPUs (48GB total) and I was hoping to put the entire model on both GPUs without offloading to CPU, but I seem to run out of space, even with the ZeRO optimizations you suggested. Is there really ~8GB of extras loaded onto the GPU? (I only expect to fit batch size of 1) Summary:
|
Ah right! I keep forgetting this is not the HF Trainer integration, so everything has to be done manually. In the HF Trainer integration it's all done already and you don't need to think about any of this. Please change the code to:
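The exact replacement snippet didn't survive here; presumably it amounts to generating through the deepspeed engine with synced GPUs, roughly:

```python
with torch.no_grad():
    # synced_gpus=True keeps all ranks running generate() in lock-step, so a rank whose
    # output is shorter doesn't stop early and stall the other shards of the model
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
```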
I fixed the example above. All gpus have to work in sync even if their output is shorter than that of the other gpus, which may happen when the inputs are different. W/o sync, if one gpu finishes early the whole ensemble hangs, because each gpu holds a shard of the model that the other gpus depend on, and the gpus gather the missing shards in the pre-forward call. So if one gpu stops, the rest can't continue. When you use the same input everywhere, the gpus sync automatically because they all finish at the same time.
We are now talking Inference only:
For training please see: |
Thank you, it works with T0pp now! And thanks for the memory analysis. My confusion was because I kept getting this cryptic error when I ran the provided script without offloading to CPU:

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
[2022-01-31 17:33:28,620] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 883030
[2022-01-31 17:33:28,621] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 883031
[2022-01-31 17:33:28,621] [ERROR] [launch.py:184:sigkill_handler] ['/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/bin/python', '-u', 'hf_zero_example.py', '--local_rank=1'] exits with return code = 1

My guess is that it is a weird manifestation of an out-of-memory error, since the program works just fine when I run it with CPU offload enabled. More reading on ...
|
Glad to hear it's not hanging any longer.
Enabling it, however, shouldn't impact the overall memory usage, so if everything works when you enable it ... How do I reproduce this issue? |
bigscience/T0* multi-gpu inference with ZeRO

I adjusted the title of this Issue and re-opened it since we are clearly still working through it. |
I have been trying this, @stas00. I am using 1 A100 GPU, but it's pretty slow; I tried batch size 10, but don't see much difference. Any idea how to improve inference speed? It's using only 4285MiB / 40536MiB. |
Please give me more context, @tuhinjubcse. Are you trying the script I shared above? So 40GB A100 - got that part. For T0_3B disable cpu offloading (instructions in the script) and it will be much faster. The bigger T0 models I don't think will fit into a single GPU w/o offload, unless you use full bf16 instead of mixed precision - so basically cast your model and inputs to bf16; at bf16 you will only need 22GB for the weights and then some for activations. Let me know how it went. I don't think Deepspeed can do non-mixed precision; let me ask if this can be done somehow. |
I am using T0pp and yes, trying the script shared above. Happy to use multiple GPUs, but do you have any idea if that will make it faster? I didn't get this part: what is the logic of assigning one prompt to rank 0 and another to rank 1?
|
disable cpu offload.
If you get a chance please read https://huggingface.co/docs/transformers/parallelism#zero-data-parallelism - you will hopefully see that each gpu assembles a full layer and runs the inputs as if it had the full model all along. So when you use 2 gpus, you can process 2 unrelated batches at once. With 4 gpus, 4, etc. If you're using a single batch and replicate it to all gpus, they will each calculate an identical output, so your efficiency is 1/n_gpus. You're asking how to make things run faster: if you have 4 gpus, give each gpu a different input and you have a 4x speed up! whoah! The more gpus you use, the bigger the batch size you can use, so that will give you a further speedup. The example above is a demo, so you want to switch to batches and not single inputs (see the sketch below). |
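To make the batched, per-rank input idea concrete, here is a hedged sketch (the prompt lists are hypothetical; the `tokenizer`, `ds_engine`, and `local_rank` variables mirror the script above):

```python
import torch

# hypothetical per-rank prompt lists - in practice these would come from your dataset
prompts_per_rank = {
    0: ["Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy",
        "Is this review positive or negative? Review: the handle broke after one week"],
    1: ["Is this review positive or negative? Review: best purchase I made all year",
        "Is this review positive or negative? Review: arrived damaged and support never replied"],
}

rank = torch.distributed.get_rank()
batch = prompts_per_rank[rank]
inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(**inputs, synced_gpus=True)
for prompt, out in zip(batch, outputs):
    print(f"rank{rank}: {prompt!r} -> {tokenizer.decode(out, skip_special_tokens=True)!r}")
```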
Then run the script you provided with T0pp |
I'm trying to think where to find a similar setup, as I only have 1x RTX 3090, I've been trying to buy a 2nd one for a long time and it's just not available to buy :( |
Is there anything I can try in the meantime? |
I tried to reproduce on 2x A100 40GB and there is no problem there. The program completes w/o hanging or errors. |
Thanks @stas00 for the great example. I'm trying to run T0_3B inference on a single A10 GPU, so I don't need ZeRO here or multi-GPU inference. Using your suggestion to run bf16 inference without deepspeed, I'm casting both the model and the inputs to bfloat16, but PyTorch returns an error. The rough flow is as follows:
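The flow itself was lost from this copy of the thread; below is a hedged sketch of plain bf16 inference without DeepSpeed. One assumption worth stating: only the model weights can be cast to bf16 - the tokenized inputs are integer token ids and must stay integers, which may be the source of the error above.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# load the weights directly in bf16 instead of fp32
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
model.eval()

text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
inputs = tokenizer(text, return_tensors="pt").to("cuda")  # int64 token ids - do not cast these to bf16
with torch.no_grad():
    outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```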
Your help would be greatly appreciated. I've tried it with casting only the model to bfloat16 but not the inputs -- it runs, but there's no speedup over FP32. GPU utilization is also a lot lower with BF16 model weights vs. FP32. Hardware: Nvidia A10, 24GB VRAM |
|
Thank you for your response @stas00. Yeah, by casting I meant the approach you described in part 3 of your answer. And my apologies - next time I'll start a new issue. |
Hello, I was able to run the code stas00 mentioned above. The description of the task:
Question:
|
You just get each rank to generate its unique sequence. See the simple example for 2 gpus here: https://huggingface.co/docs/transformers/main/main_classes/deepspeed#custom-deepspeed-zero-inference Specifically this part:
If you run that example verbatim, you will see that each rank (gpu) prints its own generated answer. I hope you can see that your code remains exactly the same - you just make the input different for each rank.
How do you imagine using both together? You can use CPU offload if the gpus are too few, but that would make things slower. Please see: https://huggingface.co/docs/transformers/main/main_classes/deepspeed#deployment-with-one-gpu But overall, no, it won't be more efficient. |
Additionally, Deepspeed has recently released a new product called Deepspeed-Inference, which speeds up inference by splitting the processing over multiple GPUs using Tensor Parallelism. It can even handle quantized int8 input and thus requires half the resources, at the cost, of course, of slower execution. See https://github.com/bigscience-workshop/Megatron-DeepSpeed/tree/main/scripts/bloom-inference-scripts#deepspeed-inference though it's a temporary home - most likely the scripts will move to another location soon. The demo scripts are written for BLOOM but can be adapted to other models. If you run into problems please ask directly at Deepspeed Issues, as this is not my code (just the bloom demos are mine ;). |
Thank you, Stas. This is exactly my question. Let's say I have 40 GPUs but just 24 inputs in one document. Would that mean that I can take advantage of only 24 GPUs at a time? It is particularly an issue since the number of inputs in each document varies. If I assign one input to each GPU using the example you mentioned, I will run into the issue that, while processing some documents, I will have many idle GPUs (e.g. 40 GPUs but only 24 inputs in a particular document).
Or maybe there is a smarter way to iterate the GPUs over the many inputs in many documents. Thank you for your help with this issue. |
You need to understand how ZeRO works - all gpus must always work in sync, so you never have idling gpus. I suggest perhaps reading the main paper: https://arxiv.org/abs/1910.02054 You can do a single stream and then all other gpus will process it too, or you can send unique streams - so you can have 24 unique streams and the rest will get whatever input, and you can simply ignore their results (see the sketch below). But again, all gpus have to work in sync, because each gpu carries a unique shard of the weights and the other gpus can't continue w/o it. I think most likely you will want to research Deepspeed-Inference, which is faster than ZeRO-Inference, and there you don't need to bother with multiple streams, as it always has just a single stream - you just feed it a larger batch size instead, and of course you can change its size on every ... Deepspeed-Inference also uses custom fused kernels, which makes it super-fast. If some model isn't supported you can ask the Deepspeed team to add the support - it should be pretty quick. You can see the benchmark results here (albeit for a much larger model - bloom-176b). If you're planning to build a server solution, there are several WIP solutions as well, one is: |
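A small hedged sketch of the "pad with dummy inputs, ignore their outputs" idea from above - the helper name and the dummy prompt are hypothetical:

```python
# every rank must call generate() (ZeRO ranks work in lock-step), so a short document's
# inputs are padded up to world_size and the padding results are simply discarded.
def assign_inputs(doc_inputs, world_size, dummy="Is this review positive or negative? Review: n/a"):
    padded = doc_inputs + [dummy] * (world_size - len(doc_inputs))
    keep = [i < len(doc_inputs) for i in range(world_size)]
    return padded, keep

padded, keep = assign_inputs([f"question {i}" for i in range(24)], world_size=40)
# rank r processes padded[r]; after generation, discard results where keep[r] is False
```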
Thank you, Stas. I will follow your advice and do more reading. |
Hello Stas, It's me again. This is very unfortunate, since this bottleneck takes away so many of the benefits of parallelization. For example, if I have 4 GPUs with 32GB RAM each, I still can't run T0pp if I have just 50 GB of CPU RAM total. Are there any ways around this issue? |
Yes, there is. You can pre-shard the model weights into small shards. You're in luck since I already did that for T0pp: https://huggingface.co/bigscience/T0pp/tree/sharded - so all you need to do is load from that revision; please note that the revision is named sharded. All new models starting from a few months ago are added as sharded into 5-10GB shards by default; for the old ones - I sharded many - you can see the status here: #16884 And you can further reshard those models into even smaller chunks; then you can have little CPU RAM and concurrently load into many gpus no problem, e.g. to 5GB:
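The exact commands were lost from this copy; a hedged sketch of both steps, where `local_dir` is a hypothetical output path:

```python
from transformers import AutoModelForSeq2SeqLM

# 1) load the pre-sharded branch, so loading needs far less CPU RAM than one ~42GB file
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", revision="sharded")

# 2) optionally re-shard into even smaller (e.g. 5GB) files for machines with little CPU RAM
local_dir = "t0pp-5gb-shards"  # hypothetical output directory
model.save_pretrained(local_dir, max_shard_size="5GB")
```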
|
Wonderful news! Thank you very much. Is there any document/readme which can educate me on how to choose the right parameter (5GB, 10GB) and how this would affect the speed of inference? Do I understand correctly that, according to your calculations in the above-mentioned issue, running inference with T0pp would take (5GB (max_shard_size) * #-of-GPUs) of CPU RAM? |
The size of the shard is only important at loading time for many concurrent processes. The formula would be roughly this:

additional CPU RAM ≈ 2 * max_shard_size * n_gpus

so say 4 gpus and a 5GB shard:

2 * 5GB * 4 = 40GB

so 40GB of additional CPU memory will be needed to load the model. The 2x is because at some point you have the ... But also look into the very recent solutions which will be even faster than deepspeed zero: https://huggingface.co/blog/bloom-inference-pytorch-scripts - the article is written for BLOOM, but there is no reason why it shouldn't work for t0 models. Definitely for Accelerate, as it's model agnostic. For Deepspeed-Inference, which uses custom CUDA kernels, I haven't tried - if it doesn't work with the latter please ask at Deepspeed Issues. But these will be much faster solutions - unless you infer different streams for each gpu with deepspeed-zero. Please read the article and you should be able to see what would work best for you. If you try Deepspeed-Inference w/ T0 please report back success/failure, I'd like to know. Thank you! |
@stas00 Hello Stas, I have been experimenting with the code you posted, and I got strange results. I was wondering if you have any thoughts/suggestions on how I could improve the code, especially speed-wise. I ran my code with deepspeed using 4 and 1 V100 32GB GPUs, respectively. To generate answers for the exact same questions, 1 GPU finishes the job in 70 minutes; using four GPUs takes 112 minutes!!! I use your code except for the last couple of lines, which I change to:
|
It's not strange at all. When using ZeRO, 1 gpu will always be faster if you can fit the model in - this is because of the overhead of comms with multiple gpus, which a 1-gpu setup doesn't have. This is about using the right tool for the right job. ZeRO was written for models that can't fit into a single GPU. If you can use a single GPU, use it ;) Also the following nuance is very important: a 4-gpu setup in your case generates 4 different outputs and 1 gpu only one, so the effective speed of 4 gpus is 1/4th of 112 minutes, i.e. 28 minutes per gpu. Does it make sense? And there is a much faster Deepspeed-Inference solution that was released just recently https://huggingface.co/blog/bloom-inference-pytorch-scripts that will indeed make your 4 gpus speed up the inference by a lot, and faster than 1 gpu. The Accelerate solution is also likely to be faster than or on par with ZeRO when you feed the latter unique streams. I haven't tried with this particular model to tell for sure. |
Oh, I see. Although I knew the purpose of ZeRO is to fit big models, I assumed it would also help with the speed!! I do not understand the nuance you mentioned, though; I generated the same number of answers in both the 4-GPU and 1-GPU setups. To be more specific: to answer 60 questions, the 4-GPU setup took 112 minutes while the 1-GPU setup only took 70 minutes. I will try to implement the model using the document you mentioned. That's very helpful. |
I wasn't able to derive that this was the case - it's possible I missed that. If it is as you say, then the overhead of comms is really big and 4 gpus are indeed slower even with unique streams. When you have fast intranode connectivity like NVLink, as compared to PCIe, the comms overhead is usually lower; then compute dominates and gpus excel at what they do - fast results. When comms are slow, the gpus idle a lot - slow results. The same goes for multiple nodes - one node is fast, more than one node is usually much slower, since inter-node networks are usually slower than intra-node - but it's not always the case (e.g. NVSwitch can connect many nodes at almost the same speed as NVLink on one node). You can also watch your GPU utilization in nvidia-smi while processing on 1 vs 4 gpus - if it's always close to 100% then comms are super fast; if it's jumping between 0 and 100, then the overhead of moving the data around is large. nvidia-smi doesn't show the exact compute util, but it's a good enough indication. If, for example, you try NVMe offload, the gpu will be heavily under-utilized since disc IO will be slow. |
Hello Stas, Here is my report. The tricky part is that I had to downgrade my version of "transformers" to version 4.21.3. Otherwise, it doesn't work.
However, I struggle with integrating "sharded revision" into this code. I assume a pipeline can't be used for it, right? |
Thank you for sharing the outcome, @archieCanada I'd recommend opening an Issue at https://github.com/microsoft/DeepSpeed/issues so that it can be fixed on their side. Besides your repro code, please include the traceback of the failure you had with the latest transformers. I wonder if they need to add tests to ensure ...
I'm not quite sure; I use explicit code rather than pipelines so that I have full control over all the parts. I think you could open a feature request to have ... To hack around it, you could download the revision you want and load it locally, using the path to the clone rather than the model name. |
@stas00 I see that you have mentioned |
That's the thing about ZeRO - all gpus participating in the sharding must always work in sync. If one gpu finishes early for some reason, it must continue running ... You can see how I implemented it in ... You can also use ... |
Hello Stas and community, I am trying to use the sharded version of the T0pp model, but I fail to do so for reasons I don't understand.
I have the following resources: 2 GPUs (Tesla V100-SXM2-32GB). According to my understanding these resources should be enough, but the system crashes already on this line:
Most likely you have too little cpu memory; try with ... Normally, with an unsharded model, you need 2x the model size in CPU memory to load it. |
Hello Stas, I probably don't understand everything, because I am missing your point. Here is my situation: What I have tried:
In both cases, I get an error. Probably this is not the right way of loading the sharded model. Should I then save the model as sharded myself, using the following algorithm?:
Please, correct me where I am wrong. |
I was replying to your comment about getting your program killed. I think I now understand what you are struggling with: there is no such model as ... Here is what you can do:
If 10GB is too big, then next try to make your own sharded model:
This step will require 2x model-size cpu memory and then a bit more, so 100GB of CPU memory should be enough. Then use the resulting model like so:
python -c 'from transformers import AutoModelForSeq2SeqLM; ...

You will of course have to adapt the uppercase name to your situation. None of these needs a GPU. Please let me know if any of the 3 worked for you. |
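The rest of that one-liner was cut off; a hedged guess at its shape, where `PATH_TO_RESHARDED_MODEL` is a hypothetical stand-in for the uppercase placeholder mentioned above:

```python
# load the locally re-sharded checkpoint by path instead of by hub name, as a smoke test
from transformers import AutoModelForSeq2SeqLM

path = "PATH_TO_RESHARDED_MODEL"  # hypothetical placeholder for the local directory
model = AutoModelForSeq2SeqLM.from_pretrained(path)
```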
Hello Stas, I tried method 2) you described above.
My resources: I tried two different options:
Both times I received the following error:
|
But really, as I suggested originally, you should post an Issue at https://github.com/microsoft/DeepSpeed/issues and tag RezaYazdaniAminabadi - DeepSpeed-Inference isn't something that is integrated into ... Deepspeed is a project that has multiple related frameworks - we have only the ZeRO project integrated into the HF Trainer and Accelerate. DeepSpeed-Inference requires no integration. |
Environment info
transformers version: 4.17.0.dev0
Who can help
Models:
(I'm actually trying to use T0pp but T5 is close enough)
Library:
Information
Model I am using (Bert, XLNet ...): T0pp / T0_3B
The problem arises when using:
The tasks I am working on is:
To reproduce
I want to load T0pp across 2x 24GB GPUs and only run inference. I know from reading the documentation that Deepspeed with ZeRO stage 3 is the way to go for this. I am following the HuggingFace example here to use Deepspeed without a Trainer object. The error I get is:
My code:
Run with
CUDA_VISIBLE_DEVICES="0,1" deepspeed <script.py>
Expected behavior
T0pp (or T0_3B) to load across 2 GPUs, generate an answer, and then quit.