[BUG] Inference of a large model using NVMe offload. AssertionError: More elements 524709888 than buffer size 100,000,000 #3506

Open
Markshilong opened this issue May 10, 2023 · 6 comments
Labels: bug, inference


@Markshilong

Describe the bug
I'm trying to run inference on a 54-billion-parameter model (facebook/nllb-moe-54b) using NVMe offload on my laptop with an RTX 3060 (6GB GPU memory), but I get the error: AssertionError: More elements 524709888 than buffer size 100,000,000

The full error message is:

(deepspeed) mark@lsl-pc:~/Research/accelerate/examples$ deepspeed --num_gpus 1 nllb_ZeRO_inference.py 
[2023-05-10 17:29:06,119] [WARNING] [runner.py:191:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2023-05-10 17:29:06,127] [INFO] [runner.py:541:main] cmd = /home/mark/anaconda3/envs/deepspeed/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None nllb_ZeRO_inference.py
[2023-05-10 17:29:07,275] [INFO] [launch.py:229:main] WORLD INFO DICT: {'localhost': [0]}
[2023-05-10 17:29:07,275] [INFO] [launch.py:235:main] nnodes=1, num_local_procs=1, node_rank=0
[2023-05-10 17:29:07,275] [INFO] [launch.py:246:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2023-05-10 17:29:07,275] [INFO] [launch.py:247:main] dist_world_size=1
[2023-05-10 17:29:07,275] [INFO] [launch.py:249:main] Setting CUDA_VISIBLE_DEVICES=0
[2023-05-10 17:29:08,390] [INFO] [comm.py:622:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2023-05-10 17:29:08,611] [WARNING] [config_utils.py:69:_process_deprecated_field] Config parameter stage3_gather_fp16_weights_on_model_save is deprecated use gather_16bit_weights_on_model_save instead
[2023-05-10 17:29:11,725] [INFO] [utils.py:30:print_object] AsyncPartitionedParameterSwapper:
[2023-05-10 17:29:11,725] [INFO] [utils.py:34:print_object]   aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2023-05-10 17:29:11,725] [INFO] [utils.py:34:print_object]   aio_handle ................... <class 'async_io.aio_handle'>
[2023-05-10 17:29:11,725] [INFO] [utils.py:34:print_object]   aligned_bytes ................ 1024
[2023-05-10 17:29:11,725] [INFO] [utils.py:34:print_object]   aligned_elements_per_buffer .. 100000256
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   available_buffer_ids ......... [0, 1, 2, 3, 4]
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   available_numel .............. 0
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   available_params ............. set()
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   dtype ........................ torch.float16
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   elements_per_buffer .......... 100,000,000
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   id_to_path ................... {}
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   inflight_numel ............... 0
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   inflight_params .............. []
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   inflight_swap_in_buffers ..... []
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   invalid_buffer ............... 1.0
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   min_aio_bytes ................ 1048576
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   numel_alignment .............. 512
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   param_buffer_count ........... 5
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   param_id_to_buffer_id ........ {}
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   param_id_to_numel ............ {}
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   param_id_to_swap_buffer ...... {}
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   partitioned_swap_buffer ...... None
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   partitioned_swap_pool ........ None
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   pending_reads ................ 0
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   pending_writes ............... 0
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   reserved_buffer_ids .......... []
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   swap_config .................. device='nvme' nvme_path=PosixPath('/home/mark/Research/nvme_offload_path') buffer_count=5 buffer_size=100,000,000 max_in_cpu=1,000,000,000 pin_memory=True
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   swap_element_size ............ 2
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   swap_folder .................. /home/mark/Research/nvme_offload_path/zero_stage_3/float16params/rank0
[2023-05-10 17:29:11,726] [INFO] [utils.py:34:print_object]   swap_out_params .............. []
[2023-05-10 17:29:11,778] [INFO] [partition_parameters.py:454:__exit__] finished initializing model with 0.52B parameters
Traceback (most recent call last):
  File "/home/mark/Research/accelerate/examples/nllb_ZeRO_inference.py", line 187, in <module>
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, low_cpu_mem_usage=True)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 471, in from_pretrained
    return model_class.from_pretrained(
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2629, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 382, in wrapper
    f(module, *args, **kwargs)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/nllb_moe/modeling_nllb_moe.py", line 1658, in __init__
    self.model = NllbMoeModel(config)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 382, in wrapper
    f(module, *args, **kwargs)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/nllb_moe/modeling_nllb_moe.py", line 1517, in __init__
    self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 389, in wrapper
    self._post_init_method(module)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 822, in _post_init_method
    param.partition()
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 948, in partition
    self._partition(param_list, has_been_updated=has_been_updated)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1086, in _partition
    self._partition_param(param, has_been_updated=has_been_updated)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1136, in _partition_param
    buffer = self.param_swapper.get_buffer(param, partition_size)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py", line 343, in get_buffer
    assert numel < self.elements_per_buffer, f"More elements {numel} than buffer size {self.elements_per_buffer}"
AssertionError: More elements 524709888 than buffer size 100,000,000
[2023-05-10 17:29:13,282] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 17152
[2023-05-10 17:29:13,282] [ERROR] [launch.py:434:sigkill_handler] ['/home/mark/anaconda3/envs/deepspeed/bin/python', '-u', 'nllb_ZeRO_inference.py', '--local_rank=0'] exits with return code = 1

The ds_config is:

ds_config = {
    "fp16": {
        "enabled": False
    },
    "bf16": {
        "enabled": False
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/home/mark/Research/nvme_offload_path",
            "buffer_count": 6,
            "buffer_size": 6e8,
            "max_in_cpu": 1e9
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.1 * model_hidden_size * model_hidden_size,
        "stage3_max_live_parameters": 1e8,
        "stage3_max_reuse_distance": 1e8,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "aio": {
        "block_size": 262144,
        "queue_depth": 32,
        "thread_count": 1,
        "single_submit": False,
        "overlap_events": True
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

Then I tried changing "buffer_size" in "offload_param" of ds_config from 1e8 to 6e8, but then I got 'CUDA out of memory' like this:

[2023-05-10 17:55:31,469] [INFO] [comm.py:622:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
Traceback (most recent call last):
  File "/home/mark/Research/accelerate/examples/nllb_ZeRO_inference.py", line 187, in <module>
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, low_cpu_mem_usage=True)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 471, in from_pretrained
    return model_class.from_pretrained(
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2624, in from_pretrained
    init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 762, in __init__
    self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py", line 45, in __init__
    self._configure_aio(ds_config)
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py", line 107, in _configure_aio
    self.buffers = get_accelerator().pin_memory(
  File "/home/mark/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed/accelerator/cuda_accelerator.py", line 217, in pin_memory
    return tensor.pin_memory()
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[2023-05-10 17:55:35,370] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 19108
[2023-05-10 17:55:35,371] [ERROR] [launch.py:434:sigkill_handler] ['/home/mark/anaconda3/envs/deepspeed/bin/python', '-u', 'nllb_ZeRO_inference.py', '--local_rank=0'] exits with return code = 1

It might work on a GPU with more memory, but how can I run NVMe offload on this GPU with only 6GB of memory?
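For context, the assertion fires when a single parameter partition is larger than one offload buffer: the traceback above fails while partitioning the shared nn.Embedding, whose numel (524,709,888 elements) exceeds the configured buffer_size of 1e8 elements. A minimal sketch of how the required buffer_size could be estimated from the model config (assuming the published facebook/nllb-moe-54b config and the 512-element numel_alignment shown in the swapper log):

# Hedged sketch: estimate the largest single parameter tensor from the model config
# so that offload_param.buffer_size can be set at least that large.
import math
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/nllb-moe-54b")

# The traceback fails inside nn.Embedding(vocab_size, config.d_model, padding_idx),
# so the shared embedding is the tensor that must fit into one buffer.
largest_param_numel = config.vocab_size * config.d_model
print(f"shared embedding numel: {largest_param_numel:,}")  # should match the 524,709,888 in the error

# Round up to the swapper's numel_alignment (512 in the log above) for a safe lower bound.
numel_alignment = 512
min_buffer_size = math.ceil(largest_param_numel / numel_alignment) * numel_alignment
print(f"offload_param.buffer_size should be larger than {min_buffer_size:,} elements")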

To Reproduce
Steps to reproduce the behavior:
My script is modified from huggingface/transformers#16616.
Here is my script; I run it with 'deepspeed --num_gpus 1 nllb_ZeRO_inference.py'.

#!/usr/bin/env python

# from: https://github.com/huggingface/transformers/issues/16616
# This script demonstrates how to use Deepspeed ZeRO in an inference mode when one can't fit a model
# into a single GPU
#
# 1. Use 1 GPU with CPU offload
# 2. Or use multiple GPUs instead
#
# First you need to install deepspeed: pip install deepspeed
#
# Here we use a 3B "bigscience/T0_3B" model which needs about 15GB GPU RAM - so 1 largish or 2
# small GPUs can handle it. Or 1 small GPU and a lot of CPU memory.
#
# To use a larger model like "bigscience/T0" which needs about 50GB, unless you have an 80GB GPU -
# you will need 2-4 gpus. And then you can adapt the script to handle more gpus if you want to
# process multiple inputs at once.
#
# The provided deepspeed config also activates CPU memory offloading, so chances are that if you
# have a lot of available CPU memory and you don't mind a slowdown you should be able to load a
# model that doesn't normally fit into a single GPU. If you have enough GPU memory the program will
# run faster if you don't want offload to CPU - so disable that section then.
#
# To deploy on 1 gpu:
#
# deepspeed --num_gpus 1 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=1 t0.py
#
# To deploy on 2 gpus:
#
# deepspeed --num_gpus 2 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=2 t0.py


from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
import torch

os.environ["TOKENIZERS_PARALLELISM"] = "False"  # To avoid warnings about parallelism in tokenizers

# distributed setup
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()

# model_name = "bigscience/T0"
# model_name = "bigscience/T0_3B"
model_name = "facebook/nllb-moe-54b"

config = AutoConfig.from_pretrained(model_name)
model_hidden_size = config.d_model

# batch size has to be divisible by world_size, but can be bigger than world_size
train_batch_size = 1 * world_size

# ds_config notes
#
# - enable bf16 if you use Ampere or higher GPU - this will run in mixed precision and will be
# faster.
#
# - for older GPUs you can enable fp16, but it'll only work for non-bf16 pretrained models - e.g.
# all official t5 models are bf16-pretrained
#
# - set offload_param.device to "none" or completely remove the `offload_param` section if you don't
#   want CPU offload
#
# - if using `offload_param` you can manually finetune stage3_param_persistence_threshold to control
#   which params should remain on gpus - the larger the value the smaller the offload size
#
# For indepth info on Deepspeed config see
# https://huggingface.co/docs/transformers/main/main_classes/deepspeed

# XXX: modified this script to use nvme offload so need to explain the new configs, but the key is
# to change the path to `nvme_path`

# keeping the same format as json for consistency, except that json uses lowercase true/false while Python uses True/False
# fmt: off
ds_config = {
    "fp16": {
        "enabled": False
    },
    "bf16": {
        "enabled": False
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/home/mark/Research/nvme_offload_path",
            "buffer_count": 6,
            "buffer_size": 1e8,
            "max_in_cpu": 1e9
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.1 * model_hidden_size * model_hidden_size,
        "stage3_max_live_parameters": 1e8,
        "stage3_max_reuse_distance": 1e8,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "aio": {
        "block_size": 262144,
        "queue_depth": 32,
        "thread_count": 1,
        "single_submit": False,
        "overlap_events": True
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}
# fmt: on

# next line instructs transformers to partition the model directly over multiple gpus using
# deepspeed.zero.Init when model's `from_pretrained` method is called.
#
# **it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name)**
#
# otherwise the model will first be loaded normally and only partitioned at forward time, which is
# less efficient and may fail when there is little CPU RAM
dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

# now a model can be loaded.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, low_cpu_mem_usage=True)

# initialise Deepspeed ZeRO and store only the engine object
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()  # inference

# Deepspeed ZeRO can process unrelated inputs on each GPU. So for 2 gpus you process 2 inputs at once.
# If you use more GPUs adjust for more.
# And of course if you have just one input to process you then need to pass the same string to both gpus.
# If you use only one GPU, then you will have only rank 0.
rank = torch.distributed.get_rank()
if rank == 0:
    text_in = "what do you think of president Obama?"
elif rank == 1:
    text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"


tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt", padding=True).to(device=local_rank)
#from transformers.deepspeed import is_deepspeed_zero3_enabled
#print(f"Deepspeed 3 is enabled: {is_deepspeed_zero3_enabled()}")
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")


System info (please complete the following information):

  • OS: Ubuntu 22.04
  • Memory: 16GB RAM and 10GB swap
  • GPU count and types: 1x RTX 3060 Laptop GPU with 6GB memory
  • (if applicable) what DeepSpeed-MII version are you using
  • (if applicable) Hugging Face Transformers/Accelerate/etc. versions
  • Python version: 3.10
@Markshilong added the bug and inference labels on May 10, 2023
@andre-bauer

Any updates on this? I have the same issue.

@tjruwase
Contributor

Actually, I just noticed @Markshilong had tried changing buffer_size to 6e8 and got a CPU OOM because of memory pinning. This OOM is likely because the system cannot provide the required ~7.2GB of pinned CPU memory.

@andre-bauer, can you please share the result of changing buffer_size to 6e8?
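As a back-of-the-envelope check on the ~7.2GB figure above, the pinned host allocation for the parameter swap buffers scales as buffer_count x buffer_size x element size. A sketch, assuming 2-byte fp16/bf16 elements and the buffer values from the posted ds_config:

# Rough sketch of the pinned CPU memory requested for offload_param swap buffers,
# assuming fp16/bf16 parameters (2 bytes per element).
buffer_count = 6          # from the posted ds_config
buffer_size = int(6e8)    # elements per buffer after the change
bytes_per_element = 2     # torch.float16 / torch.bfloat16

pinned_gb = buffer_count * buffer_size * bytes_per_element / 1e9
print(f"pinned buffer memory: {pinned_gb:.1f} GB")  # ~7.2 GB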

@tjruwase
Contributor

Another thing to try is to disable param prefetching/caching and reduce buffer_count, as shown below.

 "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/home/mark/Research/nvme_offload_path",
            "buffer_count": 2,
            "buffer_size": 6e8,
            "max_in_cpu": 0
        },
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0,
        "stage3_max_live_parameters": 1e8,
        "stage3_max_reuse_distance": 0,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },  
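For comparison with the earlier estimate, buffer_count 2 with buffer_size 6e8 pins roughly 2 x 6e8 x 2 bytes ≈ 2.4 GB instead of ~7.2 GB, and zeroing stage3_prefetch_bucket_size, stage3_max_reuse_distance, and max_in_cpu should further limit how many parameters are cached on the GPU and in CPU memory, at some cost in speed (a rough reading of these knobs, not a guarantee).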

@andre-bauer

@tjruwase Thank you for your answer. Changing buffer_size fixed this issue, but now I get an OOM during the first inference step. I think I need to play around with stage3_max_live_parameters and stage3_param_persistence_threshold, right? Any suggestions on what to try first? I am running bloom-176B on L4s (24GB memory).
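For illustration only, a minimal sketch of how those two knobs could be lowered (the values are arbitrary assumptions, not recommendations from this thread; both are element counts that bound what stays resident on the GPU, so lowering them trades speed for memory):

# Hypothetical sketch only, not a verified recipe for bloom-176B.
from transformers import AutoConfig

model_hidden_size = AutoConfig.from_pretrained("bigscience/bloom").hidden_size

zero_overrides = {
    "stage3_max_live_parameters": 3e7,                        # down from 1e8 in the earlier config
    "stage3_param_persistence_threshold": model_hidden_size,  # down from 10 * model_hidden_size
}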

@tjruwase
Contributor

@andre-bauer, can you please share a stack trace of the OOM?

@0781532

0781532 commented Jan 22, 2024

Any update on this issue? I got the same error.

[2024-01-22 16:50:44,316] [INFO] [utils.py:799:see_memory_usage] CPU Virtual Memory: used = 26.4 GB, percent = 84.6%
[2024-01-22 16:50:49,553] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 1080504

I am fine-tuning GPT-3 6.7B on a single RTX 3090 GPU (24GB memory).

This is my config file:

{
    "train_batch_size": CONFIG_BATCH_SIZE,
    "train_micro_batch_size_per_gpu": CONFIG_MBSIZE,
    "steps_per_print": LOG_INTERVAL,

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "nvme",
            "pin_memory": true,
            "ratio": 0.3,
            "buffer_count": 4,
            "fast_init": false
        },
        "offload_param": {
            "device": "nvme",
            "nvme_path": "nvme",
            "pin_memory": true,
            "buffer_count": 5,
            "buffer_size": 1e9,
            "max_in_cpu": 1e9
        },

        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": 0,
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e8,
        "stage3_max_reuse_distance": 0,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_clipping": 1.0,
    "prescale_gradients": false,

    "fp16": {
        "enabled": CONFIG_FP16_ENABLED,
        "loss_scale": 0,
        "loss_scale_window": 500,
        "hysteresis": 2,
        "min_loss_scale": 1,
        "initial_scale_power": 11
    },

    "bf16": {
        "enabled": CONFIG_BF16_ENABLED
    },

    "wall_clock_breakdown": false
}


This is my log:

[2024-01-22 16:50:41,962] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.12.7+870ae041, git-hash=870ae041, git-branch=master
[2024-01-22 16:50:42,005] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2024-01-22 16:50:42,006] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2024-01-22 16:50:42,006] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2024-01-22 16:50:42,012] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2024-01-22 16:50:42,012] [INFO] [utils.py:56:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2024-01-22 16:50:42,012] [INFO] [logging.py:96:log_dist] [Rank 0] Creating fp16 ZeRO stage 3 optimizer, MiCS is enabled False, Hierarchical params gather False
[2024-01-22 16:50:42,012] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.bfloat16 ZeRO stage 3 optimizer
[2024-01-22 16:50:42,049] [INFO] [utils.py:791:see_memory_usage] Stage 3 initialize beginning
[2024-01-22 16:50:42,049] [INFO] [utils.py:792:see_memory_usage] MA 3.78 GB Max_MA 4.01 GB CA 4.04 GB Max_CA 4 GB
[2024-01-22 16:50:42,049] [INFO] [utils.py:799:see_memory_usage] CPU Virtual Memory: used = 26.39 GB, percent = 84.6%
[2024-01-22 16:50:42,050] [INFO] [stage3.py:128:init] Reduce bucket size 500,000,000
[2024-01-22 16:50:42,050] [INFO] [stage3.py:129:init] Prefetch bucket size 0
[2024-01-22 16:50:42,085] [INFO] [utils.py:791:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]
[2024-01-22 16:50:42,086] [INFO] [utils.py:792:see_memory_usage] MA 3.78 GB Max_MA 3.78 GB CA 4.04 GB Max_CA 4 GB
[2024-01-22 16:50:42,086] [INFO] [utils.py:799:see_memory_usage] CPU Virtual Memory: used = 26.39 GB, percent = 84.6%
[2024-01-22 16:50:42,765] [INFO] [utils.py:30:print_object] AsyncPartitionedParameterSwapper:
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] aio_handle ................... <class 'async_io.aio_handle'>
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] aligned_bytes ................ 1024
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] aligned_elements_per_buffer .. 1000000000
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] available_buffer_ids ......... [0, 1, 2, 3, 4]
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] available_numel .............. 0
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] available_params ............. set()
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] dtype ........................ torch.bfloat16
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] elements_per_buffer .......... 1000000000
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] id_to_path ................... {}
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] inflight_numel ............... 0
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] inflight_params .............. []
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] inflight_swap_in_buffers ..... []
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] invalid_buffer ............... 1.0
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] min_aio_bytes ................ 1048576
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] numel_alignment .............. 512
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] param_buffer_count ........... 5
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] param_id_to_buffer_id ........ {}
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] param_id_to_numel ............ {}
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] param_id_to_swap_buffer ...... {}
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] partitioned_swap_buffer ...... None
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] partitioned_swap_pool ........ None
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] pending_reads ................ 0
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] pending_writes ............... 0
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] reserved_buffer_ids .......... []
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] swap_config .................. device='nvme' nvme_path=PosixPath('nvme') buffer_count=5 buffer_size=1000000000 max_in_cpu=1000000000 pin_memory=True
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] swap_element_size ............ 2
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] swap_folder .................. nvme/zero_stage_3/bfloat16params/rank0
[2024-01-22 16:50:42,765] [INFO] [utils.py:34:print_object] swap_out_params .............. []
Parameter Offload: Total persistent parameters: 803840 in 194 params
[2024-01-22 16:50:44,239] [INFO] [utils.py:791:see_memory_usage] DeepSpeedZeRoOffload initialize [end]
[2024-01-22 16:50:44,239] [INFO] [utils.py:792:see_memory_usage] MA 0.0 GB Max_MA 3.78 GB CA 4.04 GB Max_CA 4 GB
[2024-01-22 16:50:44,239] [INFO] [utils.py:799:see_memory_usage] CPU Virtual Memory: used = 26.4 GB, percent = 84.6%
Using /home/tflow/.cache/torch_extensions/py310_cu116 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/tflow/.cache/torch_extensions/py310_cu116/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_adam...
Time to load fused_adam op: 0.037041664123535156 seconds
[2024-01-22 16:50:44,315] [INFO] [utils.py:791:see_memory_usage] Before creating fp16 partitions
[2024-01-22 16:50:44,315] [INFO] [utils.py:792:see_memory_usage] MA 0.0 GB Max_MA 0.0 GB CA 4.04 GB Max_CA 4 GB
[2024-01-22 16:50:44,316] [INFO] [utils.py:799:see_memory_usage] CPU Virtual Memory: used = 26.4 GB, percent = 84.6%
[2024-01-22 16:50:49,553] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 1080504

Any help would be greatly appreciated.
Thanks!
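As a quick sanity check on the config above, the offload_param section alone requests a sizable pinned host allocation. A rough sketch, assuming 2-byte bf16 elements as shown in the swapper log:

# Rough estimate of pinned CPU memory for the offload_param swap buffers in the config above,
# assuming 2 bytes per element (torch.bfloat16 per the swapper log).
param_buffer_count = 5
param_buffer_size = int(1e9)   # elements per buffer
bytes_per_element = 2

pinned_gb = param_buffer_count * param_buffer_size * bytes_per_element / 1e9
print(f"offload_param pinned buffers: {pinned_gb:.0f} GB")  # ~10 GB

The offload_optimizer buffers add to this, so the subprocess being killed right after "CPU Virtual Memory: used = 26.4 GB, percent = 84.6%" would be consistent with the host running out of memory while pinning these buffers, though that is only an inference from the log.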
