Error: BFloat16 Unsupported scalar when trying to execute across multiple GPUs with BFloat16 & 8-Bits #79

Closed
FTuma opened this issue Oct 18, 2022 · 2 comments

FTuma commented Oct 18, 2022

I tried to run BLOOM distributed across multiple A100 GPUs in 8-bit mode with bfloat16, but ran into this error while executing a slightly adjusted version of the example script:

===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link
================================================================================
CUDA SETUP: CUDA runtime path found: /datadrive/miniconda3/envs/petals/lib/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /datadrive/miniconda3/envs/petals/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...
Oct 18 09:52:07.795 [WARN] [/datadrive/repos/petals/src/client/remote_sequential.py.__init__:34] RemoteSequential is in active development; expect adventures
Some weights of DistributedBloomForCausalLM were not initialized from the model checkpoint at bloom-testing/test-bloomd-560m-main and are newly initialized: ['lm_head.word_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Traceback (most recent call last):
  File "/datadrive/repos/petals/simple_test_script.py", line 17, in <module>
    remote_outputs = model.generate(inputs, max_length=100)
  File "/datadrive/miniconda3/envs/petals/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/datadrive/repos/petals/src/client/remote_generation.py", line 113, in generate
    hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
  File "/datadrive/repos/petals/src/client/inference_session.py", line 200, in step
    outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
  File "/datadrive/repos/petals/src/client/inference_session.py", line 109, in step
    tensors=[
  File "/datadrive/repos/petals/src/client/inference_session.py", line 110, in <listcomp>
    serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
  File "/datadrive/miniconda3/envs/petals/lib/python3.9/site-packages/hivemind/compression/serialization.py", line 41, in serialize_torch_tensor
    return compression.compress(tensor, info, allow_inplace)
  File "/datadrive/miniconda3/envs/petals/lib/python3.9/site-packages/hivemind/compression/base.py", line 83, in compress
    array = tensor.detach().numpy()
TypeError: Got unsupported ScalarType BFloat16

The code of simple_test_script.py:

import os

import torch
import torch.nn.functional as F
import transformers
from src import DistributedBloomForCausalLM

MODEL_NAME = "bloom-testing/test-bloomd-560m-main"  # or "bigscience/bloom-petals"
initial_peer = os.getenv("initial_peer")  # multiaddr of the first server, read from the environment
initial_peers = [initial_peer]  # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
  MODEL_NAME, initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
)  # this model has only embeddings / logits, all transformer blocks rely on remote servers

# model = model.to('cuda')
inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(inputs, max_length=100)
print(tokenizer.decode(remote_outputs[0]))  # "a cat sat in the back of the car,"

# "train" input embeddings by backprop through distributed transformer blocks
model.transformer.word_embeddings.weight.requires_grad = True
outputs = model.forward(input_ids=inputs)
loss = F.cross_entropy(outputs.logits.flatten(0, 1), inputs.flatten())
loss.backward()
print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())

The servers were launched with these commands:

python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 12 --torch_dtype bfloat16 --host_maddrs /ip4/0.0.0.0/tcp/31337 --load_in_8bit

python -m cli.run_server bloom-testing/test-bloomd-560m-main --torch_dtype bfloat16 --host_maddrs /ip4/127.0.0.1/tcp/0 --load_in_8bit --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmTHnjwKQFzvxrPesrSjtaL5eKUVdHfLsxV87vx8RFH21U --block_indices 12:24 --device cuda:1
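For the full model, the same pattern (one server process per GPU, each serving a contiguous range of blocks) could be generated with a small helper. This is a hypothetical sketch: `split_blocks` is not part of Petals, it assumes `--block_indices start:end` is a half-open range (as the two commands above suggest: 12 blocks on the first server, `12:24` on the second), and it assumes server `i` runs on `cuda:i`. Only the `--block_indices` and `--device` flags from the commands above are reused; the remaining flags would stay as shown.

```python
def split_blocks(num_blocks: int, num_servers: int) -> list[str]:
    """Split blocks 0..num_blocks-1 into contiguous, near-equal "start:end" ranges (hypothetical helper)."""
    if not 1 <= num_servers <= num_blocks:
        raise ValueError("need between 1 and num_blocks servers")
    base, extra = divmod(num_blocks, num_servers)
    ranges, start = [], 0
    for i in range(num_servers):
        size = base + (1 if i < extra else 0)  # the first `extra` servers take one extra block
        ranges.append(f"{start}:{start + size}")
        start += size
    return ranges


# 24 blocks over 2 GPUs reproduces the split used in the commands above.
for gpu, block_range in enumerate(split_blocks(24, 2)):
    print(f"--block_indices {block_range} --device cuda:{gpu}")
```

For BLOOM-176B one would pass its own block count (70) and the number of GPUs, e.g. `split_blocks(70, 8)`.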

The packages in the environment were installed via requirements.txt:

# packages in environment at /datadrive/miniconda3/envs/petals:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
accelerate                0.10.0                   pypi_0    pypi
aiohttp                   3.8.3                    pypi_0    pypi
aiosignal                 1.2.0                    pypi_0    pypi
asttokens                 2.0.5              pyhd3eb1b0_0  
async-timeout             4.0.2                    pypi_0    pypi
attrs                     22.1.0                   pypi_0    pypi
backcall                  0.2.0              pyhd3eb1b0_0  
base58                    2.1.1                    pypi_0    pypi
bitsandbytes              0.34.0                   pypi_0    pypi
blas                      1.0                         mkl  
brotlipy                  0.7.0           py39h27cfd23_1003  
bzip2                     1.0.8                h7b6447c_0  
ca-certificates           2022.07.19           h06a4308_0  
certifi                   2022.9.24        py39h06a4308_0  
cffi                      1.15.1           py39h74dc2b5_0  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.3                    pypi_0    pypi
configargparse            1.5.3                    pypi_0    pypi
cryptography              37.0.1           py39h9ce1e76_0  
cudatoolkit               11.3.1               h2bc3f7f_2  
datasets                  2.5.2                    pypi_0    pypi
debugpy                   1.5.1            py39h295c915_0  
decorator                 5.1.1              pyhd3eb1b0_0  
dill                      0.3.5.1                  pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
entrypoints               0.4              py39h06a4308_0  
executing                 0.8.3              pyhd3eb1b0_0  
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.8.0                    pypi_0    pypi
freetype                  2.11.0               h70c0345_0  
frozenlist                1.3.1                    pypi_0    pypi
fsspec                    2022.8.2                 pypi_0    pypi
giflib                    5.2.1                h7b6447c_0  
gitdb                     4.0.9                    pypi_0    pypi
gitpython                 3.1.29                   pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gnutls                    3.6.15               he1e5248_0  
grpcio                    1.49.1                   pypi_0    pypi
grpcio-tools              1.48.2                   pypi_0    pypi
hivemind                  1.1.1                    pypi_0    pypi
huggingface-hub           0.7.0                    pypi_0    pypi
humanfriendly             10.0                     pypi_0    pypi
idna                      3.3                pyhd3eb1b0_0  
intel-openmp              2021.4.0          h06a4308_3561  
ipykernel                 6.15.2           py39h06a4308_0  
ipython                   8.4.0            py39h06a4308_0  
jedi                      0.18.1           py39h06a4308_1  
jpeg                      9e                   h7f8727e_0  
jupyter_client            7.3.5            py39h06a4308_0  
jupyter_core              4.11.1           py39h06a4308_0  
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libdeflate                1.8                  h7f8727e_5  
libffi                    3.3                  he6710b0_2  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libiconv                  1.16                 h7f8727e_2  
libidn2                   2.3.2                h7f8727e_0  
libpng                    1.6.37               hbc83047_0  
libsodium                 1.0.18               h7b6447c_0  
libstdcxx-ng              11.2.0               h1234567_1  
libtasn1                  4.16.0               h27cfd23_0  
libtiff                   4.4.0                hecacb30_0  
libunistring              0.9.10               h27cfd23_0  
libwebp                   1.2.4                h11a3e52_0  
libwebp-base              1.2.4                h5eee18b_0  
lz4-c                     1.9.3                h295c915_1  
matplotlib-inline         0.1.6            py39h06a4308_0  
mkl                       2021.4.0           h06a4308_640  
mkl-service               2.4.0            py39h7f8727e_0  
mkl_fft                   1.3.1            py39hd3c417c_0  
mkl_random                1.2.2            py39h51133e4_0  
msgpack                   1.0.4                    pypi_0    pypi
multiaddr                 0.0.9                    pypi_0    pypi
multidict                 6.0.2                    pypi_0    pypi
multiprocess              0.70.13                  pypi_0    pypi
ncurses                   6.3                  h5eee18b_3  
nest-asyncio              1.5.5            py39h06a4308_0  
netaddr                   0.8.0                    pypi_0    pypi
nettle                    3.7.3                hbbd107a_1  
numpy                     1.23.1           py39h6c91a56_0  
numpy-base                1.23.1           py39ha15fc14_0  
openh264                  2.1.1                h4ff587b_0  
openssl                   1.1.1q               h7f8727e_0  
packaging                 21.3               pyhd3eb1b0_0  
pandas                    1.5.0                    pypi_0    pypi
parso                     0.8.3              pyhd3eb1b0_0  
pathtools                 0.1.2                    pypi_0    pypi
pexpect                   4.8.0              pyhd3eb1b0_3  
pickleshare               0.7.5           pyhd3eb1b0_1003  
pillow                    9.2.0            py39hace64e9_1  
pip                       22.2.2           py39h06a4308_0  
prefetch-generator        1.0.1                    pypi_0    pypi
promise                   2.3                      pypi_0    pypi
prompt-toolkit            3.0.20             pyhd3eb1b0_0  
protobuf                  3.20.3                   pypi_0    pypi
psutil                    5.9.2                    pypi_0    pypi
ptyprocess                0.7.0              pyhd3eb1b0_2  
pure_eval                 0.2.2              pyhd3eb1b0_0  
pyarrow                   9.0.0                    pypi_0    pypi
pycparser                 2.21               pyhd3eb1b0_0  
pydantic                  1.10.2                   pypi_0    pypi
pygments                  2.11.2             pyhd3eb1b0_0  
pymultihash               0.8.2                    pypi_0    pypi
pyopenssl                 22.0.0             pyhd3eb1b0_0  
pyparsing                 3.0.9            py39h06a4308_0  
pysocks                   1.7.1            py39h06a4308_0  
python                    3.9.13               haa1d7c7_1  
python-dateutil           2.8.2              pyhd3eb1b0_0  
pytorch                   1.12.1          py3.9_cuda11.3_cudnn8.3.2_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2022.4                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
pyzmq                     23.2.0           py39h6a678d5_0  
readline                  8.1.2                h7f8727e_1  
regex                     2022.9.13                pypi_0    pypi
requests                  2.28.1           py39h06a4308_0  
responses                 0.18.0                   pypi_0    pypi
scipy                     1.9.2                    pypi_0    pypi
sentry-sdk                1.9.10                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                63.4.1           py39h06a4308_0  
shortuuid                 1.0.9                    pypi_0    pypi
six                       1.16.0             pyhd3eb1b0_1  
smmap                     5.0.0                    pypi_0    pypi
sortedcontainers          2.4.0                    pypi_0    pypi
sqlite                    3.39.3               h5082296_0  
stack_data                0.2.0              pyhd3eb1b0_0  
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.12.1                   pypi_0    pypi
torchaudio                0.12.1               py39_cu113    pytorch
torchvision               0.13.1               py39_cu113    pytorch
tornado                   6.2              py39h5eee18b_0  
tqdm                      4.64.1                   pypi_0    pypi
traitlets                 5.1.1              pyhd3eb1b0_0  
transformers              4.21.3                   pypi_0    pypi
typing_extensions         4.3.0            py39h06a4308_0  
tzdata                    2022c                h04d1e81_0  
urllib3                   1.26.11          py39h06a4308_0  
uvloop                    0.17.0                   pypi_0    pypi
varint                    1.0.2                    pypi_0    pypi
wandb                     0.13.4                   pypi_0    pypi
wcwidth                   0.2.5              pyhd3eb1b0_0  
wheel                     0.37.1             pyhd3eb1b0_0  
xxhash                    3.0.0                    pypi_0    pypi
xz                        5.2.6                h5eee18b_0  
yarl                      1.8.1                    pypi_0    pypi
zeromq                    4.3.4                h2531618_0  
zlib                      1.2.12               h5eee18b_3  
zstd                      1.5.2                ha4553b6_0

I only used the small version for debugging; I need to distribute the model across multiple GPUs because I intend to run the 176B BLOOM version. I naively tried converting the tensor at that line to a supported dtype, but then another error occurred further down the line.

Since I want to do prompt tuning on 8x 40GB A100s, I think I have to use bfloat16 and 8-bit. Or is there another solution or workaround with good performance?
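A rough back-of-the-envelope estimate shows why 8-bit weights matter for this setup. The sketch counts only the weights, under stated assumptions: about 176e9 parameters, 2 bytes per parameter in bfloat16, 1 byte in 8-bit, and 40 GiB per card. Activations, attention caches, and any layers kept in higher precision are ignored, so the real requirement is higher than what this prints.

```python
# Weight-only memory estimate (assumption: activations, caches, and
# higher-precision layers are ignored).
NUM_PARAMS = 176e9   # approximate parameter count of BLOOM-176B
GPU_MEMORY_GIB = 40  # treating each A100 40GB card as 40 GiB
NUM_GPUS = 8

BYTES_PER_PARAM = {"bfloat16": 2, "int8": 1}


def weight_gib(num_params: float, dtype: str) -> float:
    """Memory needed just to hold the weights, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


total_gib = NUM_GPUS * GPU_MEMORY_GIB
for dtype in BYTES_PER_PARAM:
    need = weight_gib(NUM_PARAMS, dtype)
    verdict = "fits" if need < total_gib else "does not fit"
    print(f"{dtype}: {need:.0f} GiB of weights vs {total_gib} GiB available -> {verdict}")
```

Under these assumptions the bfloat16 weights alone (about 328 GiB) already exceed the 320 GiB across eight cards, while 8-bit weights (about 164 GiB) leave room for activations.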

@justheuristic
Collaborator

Hi there! Will look into that later today (AOE) and try to reproduce. On the surface, we should not have bfloat16 at that stage, so it should be easy to fix. brb.

borzunov added a commit to learning-at-home/hivemind that referenced this issue Nov 28, 2022
This PR implements bfloat16 support for `CompressionType.NONE` and `CompressionType.BLOCKWISE_8BIT`.

This is important for the Petals client, see bigscience-workshop/petals#79
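NumPy has no native bfloat16 dtype, which is why `tensor.detach().numpy()` in the traceback raises `TypeError: Got unsupported ScalarType BFloat16`. The standard-library sketch below illustrates why upcasting to float32 before the conversion loses nothing: bfloat16 is the top 16 bits of a float32 (same sign bit and 8-bit exponent, 7 mantissa bits), so widening just appends 16 zero bits. It is an illustration only: the helper names are hypothetical, NaN inputs are not handled, and this is not necessarily how the hivemind PR implements bfloat16 support.

```python
import struct


def float32_to_bfloat16_bits(x: float) -> int:
    """Round a value to bfloat16 (round-to-nearest-even) and return its 16 raw bits; NaN not handled."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))  # raw float32 bit pattern
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)          # ties go to the even result
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_bits_to_float32(b: int) -> float:
    """Widen bfloat16 bits to float32 exactly by appending 16 zero mantissa bits."""
    (value,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return value


b = float32_to_bfloat16_bits(3.14159)
print(hex(b), bfloat16_bits_to_float32(b))  # prints "0x4049 3.140625"
```

The trade-off is payload size: a float32 copy takes 4 bytes per value on the wire, whereas sending the raw 16-bit patterns keeps it at 2.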
@borzunov
Collaborator

Hi @FTuma!

Sorry for taking so long to look into this. The issue should be fixed now (don't forget to pull the latest commit from main before trying it out).

For reference, here are the two PRs where we did that:

I will close this issue for now, but feel free to reopen it or make a new one if you run into other issues.

mryab pushed a commit to learning-at-home/hivemind that referenced this issue Nov 29, 2022
(cherry picked from commit 1e4af43)