
Add FSDP2 support for low-bit optimizers #484

Merged
merged 43 commits into pytorch:main on Jul 16, 2024

Conversation

@gau-nernst (Collaborator) commented Jul 8, 2024

This PR turns out to be a larger change than expected.

Background

In my previous implementation of low-bit optimizers, I only compiled the optim step for a single param. Let's call this single_param_adam(). Since weights in a model can have different shapes and sizes, to avoid re-compilation:

  1. I use torch.compile(dynamic=True) to generate a dynamic kernel.
  2. Flatten weight and grad with .view(-1) before passing them to the compiled optim step, since tensors with a different rank/ndim will be re-compiled too, even with a dynamic kernel (see the sketch below).
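
A minimal sketch of that earlier approach (the Adam math and signature are simplified here; the real implementation also dequantizes/requantizes the optim states):

```python
import torch

# Sketch of the old approach: one compiled single-param step with dynamic shapes.
# The caller flattens every tensor to 1-D so differing ranks don't trigger recompiles.
@torch.compile(dynamic=True)
def single_param_adam(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps):
    exp_avg.lerp_(grad, 1 - beta1)                      # m = beta1*m + (1-beta1)*g
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)          # v = beta2*v + (1-beta2)*g^2
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt() + eps
    p.addcdiv_(exp_avg, denom, value=-lr / bias_corr1)  # p -= lr * m_hat / (sqrt(v_hat) + eps)

# caller flattens everything before the call, e.g.
# single_param_adam(p.view(-1), p.grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1),
#                   step=1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8)
```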

Support FSDP and DTensor

Everything was fine for normal tensors. To support FSDP, the following needed to be done:

  1. As param and grad are now DTensor, the optim state should also be DTensor. (NOTE: this is more of a guideline than a strict requirement. The ideal state is that the optimizer shouldn't need to "know" that param and grad are DTensor.) (NOTE 2: DTensor is also implemented as a tensor subclass.)
  2. We can torch.compile DTensor with the local_tensor being our quantized subclass.

I soon discovered that torch.compile cannot generate a dynamic kernel for the sharded dimension (see #490 for a small reproduction snippet), and using .view(-1) on DTensor before torch.compile gives strange errors. Brian has a more detailed analysis in this comment.

Following Brian's advice, I switched to compiling a static kernel for the all-params optim step instead. Briefly,

```python
# previous impl compiles this w/ dynamic=True
def single_param_adam(p, ...):
    ...

# new impl compiles this w/o dynamic, i.e. static compile
def param_groups_adam(param_groups, ...):
    for param_group in param_groups:
        for p in param_group:
            single_param_adam(p, ...)
```

Since we don't expect param_groups to change during the course of training (though it is technically possible), static compilation is all good, and it works great with DTensor. The added bonus is that 8-bit optim is now much faster, presumably thanks to static specialization plus reduced overhead from repeated single-param optim step calls.

| Optim | Memory usage | Time |
| --- | --- | --- |
| PyTorch FP32 | 12.94 GB | 8m 18s |
| bnb 8-bit | 8.32 GB | 6m 50s |
| ao 8-bit (old) | 8.32 GB | 9m 04s |
| ao 8-bit (new, this PR) | 8.32 GB | 6m 44s |
| lpmm 4-bit | 7.72 GB | 5m 59s |
| ao 4-bit (old) | 7.72 GB | 7m 00s |
| ao 4-bit (new, this PR) | 12.62 GB | 6m 50s |

Problem with 4-bit optim

Wait, but 4-bit optim memory usage explodes 🙁. It's likely due to bit-packing, as only the 4-bit optim does that, and torch.compile uses a different fusion pattern when compiling param_groups_adam(). Note that the memory usage is even higher than for the 8-bit optim and nearly equal to the FP32 optim, indicating that some dequantized optim states are materialized in global memory.

To resolve this, I thought of 2 solutions:

  • Solution 1: Since the speed improvement in this PR is not significant for 4-bit optim (7m -> 6m 50s), I can keep the old approach (dynamic compile for single_param_adam()) for 4-bit optim. To handle DTensor, I can unwrap it before calling single_param_adam().
  • Solution 2: Change the bit-packing design. Instead of packing 2 consecutive elements of each optim state ((state[::2] << 4) | state[1::2]), I can pack optim state 1 (exp_avg) and optim state 2 (exp_avg_sq) together ((exp_avg << 4) | exp_avg_sq). This would be easier for torch.compile to fuse correctly, but at the expense of (a) requiring a separate implementation for 4-bit optim, and (b) abandoning the subclass design, since unpacking and re-packing must be explicit to create exp_avg and exp_avg_sq. It also means exp_avg and exp_avg_sq can no longer be accessed directly. A sketch of the two packing layouts follows.
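
To make the two packing layouts concrete, a rough sketch (uint8 storage with 4-bit codes assumed; this is not the exact torchao code):

```python
import torch

def pack_consecutive(state: torch.Tensor) -> torch.Tensor:
    # current design: pack 2 consecutive 4-bit codes of the SAME state into one byte
    return (state[::2] << 4) | (state[1::2] & 0xF)

def unpack_consecutive(packed: torch.Tensor) -> torch.Tensor:
    return torch.stack([(packed >> 4) & 0xF, packed & 0xF], dim=-1).view(-1)

def pack_paired(exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor) -> torch.Tensor:
    # Solution 2: pack the two DIFFERENT states element-wise into one byte
    return (exp_avg << 4) | (exp_avg_sq & 0xF)

def unpack_paired(packed: torch.Tensor):
    return (packed >> 4) & 0xF, packed & 0xF

codes = torch.randint(0, 16, (8,), dtype=torch.uint8)
assert torch.equal(unpack_consecutive(pack_consecutive(codes)), codes)
```
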
| Optim | Memory usage | Time |
| --- | --- | --- |
| ao 4-bit (old, same as Solution 1) | 7.72 GB | 7m 00s |
| ao 4-bit (this PR, no 4-bit-specific solution) | 12.62 GB | 6m 50s |
| ao 4-bit (this PR, Solution 2) | 7.74 GB | 6m 54s |

There is no clear speed advantage between Solution 1 and Solution 2. However, when running the FSDP test, Solution 2 hits this error

Operator aten.__rshift__.Scalar does not have a sharding strategy registered.

which is required by bit-packing and unpacking. This should be an easy fix in PyTorch core (I have opened an issue: pytorch/pytorch#130671).

So overall, there is no advantage to Solution 2. Using Solution 1, once the excessive memory usage bug is fixed, we can simply remove the single_param_adam() call and replace it with param_groups_adam(), while keeping the tensor subclass design.

Summary

  1. To support DTensor, switch from "dynamic single-param adam step" to "static all-params adam step". 8-bit Adam is much faster as an unexpected bonus. The first optim step takes longer due to compilation.
  2. Keep "dynamic single-param adam step" for 4-bit Adam due to the excessive memory usage bug. Unwrap DTensor before calling the adam step, since DTensor does not work well with dynamic compile (see the sketch below).

New benchmark results

| Adam impl | Max memory (GB) | Time taken for 2nd epoch | Accuracy |
| --- | --- | --- | --- |
| PyTorch | 12.94 | 8m 18s | 91.14 |
| bnb 8-bit | 8.31 | 6m 50s | 90.67 |
| ao 8-bit | 8.31 | 6m 44s | 90.63 |
| ao FP8 E4M3 | 8.32 | 6m 35s | 90.98 |
| lpmm 4-bit | 7.72 | 5m 59s | 89.97 |
| ao 4-bit | 7.72 | 7m 13s | 90.05 |

pytorch-bot bot commented Jul 8, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/484

✅ No failures as of commit 87ae147 with merge base cc871c5.

@@ -48,10 +49,17 @@ def step(self, closure=None):
if grad.is_sparse:
    raise RuntimeError("Sparse gradient is not supported")

# unwrap DTensor
Contributor

cc: @wanchaol

For context: what is the advantage of not unwrapping the DTensor? The optimizer states ideally can be DTensors too, in which case they have their sharding info present on the tensor, which simplifies optimizer-state checkpointing.

However, low-bit optimizers may present challenges to this, which we should understand (from the DTensor side). It would be cool to understand the pain points and blockers (whether around functionality/UX in eager mode, compile support, etc.).


@awgu Thanks for bringing this up!

@gau-nernst wondering if there are any challenges to not unwrapping DTensors? As Andrew mentioned, optimizer states should be DTensors so that they are saved/loaded correctly when performing distributed training. We have implemented the sharding strategies for Adam in core, so if the low-bit optimizers are similar to the Adam implementation in core, we just need to append the operators to the corresponding list.

Collaborator Author

I have tried using DTensor directly (containing the low-bit tensor subclass as its local_tensor) but faced the following error

[rank0]:E0708 20:40:01.356000 128135323341888 torch/testing/_internal/common_distributed.py:664] AssertionError: s10 (could be from ["L['grad']._base._local_tensor.size()[1]"]) not in {s2: ["L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]"], s3: ["L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]"], s4: ["L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]"], s13: ["L['grad']._local_tensor.size()[0]", "L['grad']._local_tensor.size()[0]"], s11: ["L['grad']._local_tensor.storage_offset()", "L['grad']._local_tensor.storage_offset()"], s16: ["L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]"], s17: ["L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]"], s25: ["L['p']._local_tensor.size()[0]", "L['p']._local_tensor.size()[0]"]}. If this assert is failing, it could be due to the issue described in pytorch/pytorch#90665

I'm guessing it's due to me using .view(-1) on the DTensor. If I remove torch.compile, the test passes again. Also, I'm not sure whether using .view(-1) is the proper way. How does DTensor handle .view(-1)? Does it keep the same local tensor, or does it gather the tensor and reshard? We probably don't want the latter.

Adding on about .view(-1): I added it because torch.compile will recompile and eventually hit the max cache size limit when facing differently shaped inputs, even though I tried setting dynamic=True. Perhaps we should try to fix that instead (i.e. so that we don't need .view(-1) anymore, since it seems to be the source of many problems)?

I will make another branch using DTensor as optim state, so you can inspect and debug the error.

Collaborator Author

Another option is to init the optim state as DTensor, but unwrap everything to local tensors when calling the Adam step. Kind of a workaround to get distributed saving/loading while keeping torch.compile happy.


Sounds good. I think for a Shard(0)-placement DTensor (which is what FSDP2 params look like), view(-1) shouldn't do any communication/reshard; it should return a Shard(0) DTensor directly. So on the eager side it should behave exactly as we want; otherwise it's a bug in DTensor.

I'm not sure why we hit recompilation when turning on torch.compile. It would be great if you could create a branch for a repro, and we can take a look at why compile is not working as expected.

cc @bdhirsh
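
A quick eager-mode check of that expectation (a sketch; run under torchrun with 2 GPUs, and note the public import path is torch.distributed.tensor on newer PyTorch, torch.distributed._tensor on older releases):

```python
# torchrun --nproc_per_node=2 check_view.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (2,))
x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])  # FSDP2-style dim-0 sharding
y = x.view(-1)
# expectation from the comment above: still Shard(0), local shard just flattened
print(y.placements, y.to_local().shape)  # (Shard(dim=0),) torch.Size([16])
```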

@gau-nernst (Collaborator Author) commented Jul 10, 2024

Whether the compile cache size limit is hit depends on the model, more specifically on the variety of param shapes. There are 2 main sources of recompilation for my single_param_adam() function:

  1. Input shape changes (if dynamic=True, then it is when ndim, i.e. tensor rank, changes).
  2. Class object: we need to compile one version for normal-tensor optim state, and another version for quantized-tensor-subclass optim state. This is because not all params qualify for quantization (e.g. too small, size not divisible by block size, or the user might want to exclude certain params like the embedding layer).

Just the 2 sources above can result in many combinations. For a ViT model from timm that I tested here, whose params have a lot of different shapes, not using dynamic=True and .view(-1) will hit the compile cache size limit easily. Maybe for LLMs the params don't have such a variety of shapes, but it's not very nice if we don't handle the case where there is a variety of param shapes.

A simple solution (without doing .view(-1) and/or dynamic=True) is to just increase the cache size limit (see below). Not sure if it is encouraged?
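
For reference, the knob in question is a dynamo config value; a minimal sketch of bumping it (the exact limit to pick depends on how many distinct param shapes the model has):

```python
import torch._dynamo

# allow one specialized compilation per distinct param shape instead of
# falling back to eager once the default per-function limit is exceeded
torch._dynamo.config.cache_size_limit = 128
```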

Contributor

A simple solution .... is to just increase cache size limit. Not sure if it is encouraged?

It probably depends on your use case. One failure-mode when increasing the cache limit is if e.g. your inputs have an unbounded amount of dynamism and you end up recompiling on every new input that comes in. But if you know that you have a fixed number of parameters with distinct shapes (e.g. ~50) and you know that you will only need to recompile once per param shape, bumping the cache size limit to what you need is probably ok

There are 2 main sources of recompilation for my single_param_adam()

Mostly a curiosity question: is there anything stopping you from compiling the entire step function? Or maybe a subset of it, that performs single_param_adam() on every parameter? That way you get one giant graph with all of the params (allowing inductor to potentially fuse more things), instead of many small graphs.

Regardless, we want to fix the dynamic shape issues and this is probably a good reason to keep pushing on pytorch/pytorch#124619

Collaborator Author

It probably depends on your use case. One failure-mode when increasing the cache limit is if e.g. your inputs have an unbounded amount of dynamism and you end up recompiling on every new input that comes in. But if you know that you have a fixed number of parameters with distinct shapes (e.g. ~50) and you know that you will only need to recompile once per param shape, bumping the cache size limit to what you need is probably ok

That's why I'm leaning towards the unwrapping-DTensor "hack": dynamic=True + .view(-1) already works well (at least I think so) for the non-FSDP setup, so it's natural to "hack" it to support DTensor.

is there anything stopping you from compiling the entire step function?

I thought of it but haven't tried it myself. I can give it a try.

So in the end, how should we decide which approach to use to close this PR? Maybe I will first try the horizontal fusion that you mention, to see if it can solve all of the problems.

Collaborator Author

Actually, re-reading @awgu's comment that was not answered:

What is the advantage of not unwrapping the DTensor?

In other words, what are the problems that make unwrapping DTensor discouraged? The quantized optim state can be wrapped in DTensor and then unwrapped before the adam step too.

Contributor

In other words, what are the problems that make unwrapping DTensor discouraged? The quantized optim state can be wrapped in DTensor and then unwrapped before the adam step too.

Ah, you are saying you can keep the optimizer states as DTensors but just unwrap them for the actual step()? If so, I think that works. We would only not prefer it because it requires the optimizer code to be aware of DTensor. (Ideally, like in the native torch optimizers, we do not need to change the code to make it work with DTensor, but this is just the ideal state.)

```python
if isinstance(m, TransformerBlock):
    fully_shard(m, **fsdp_kwargs)
fully_shard(fsdp_model, **fsdp_kwargs)
fsdp_optim = low_bit_optim.Adam8bit(fsdp_model.parameters(), lr=1e-2)
```
Contributor

Is .grad converted from high precision (fp32 or bf16) to 8-bit inside Adam8bit? Asking because FSDP2 seems to be performing gradient reduction in high precision, based on fsdp_kwargs.

Collaborator Author

Grad is in its original precision. To summarize (see the sketch below):

  1. The optim state is dequantized to FP32 (maybe I need to change it to the param or grad dtype? I haven't tested full BF16 training).
  2. The adam step runs as usual.
  3. The optim state is quantized back to 8-bit.
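
Roughly, the per-state flow looks like the sketch below. The toy absmax scheme here only stands in for torchao's block-wise quantization; it is not the actual implementation.

```python
import torch

def quantize_8bit(x: torch.Tensor):
    scale = x.abs().amax().clamp(min=1e-12) / 127
    return torch.round(x / scale).to(torch.int8), scale

def dequantize(codes: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return codes.to(torch.float32) * scale

def update_exp_avg(codes, scale, grad, beta1=0.9):
    exp_avg = dequantize(codes, scale)               # 1. dequantize state to FP32
    exp_avg = exp_avg.lerp(grad.float(), 1 - beta1)  # 2. usual Adam math; grad keeps its precision
    return quantize_8bit(exp_avg)                    # 3. quantize state back to 8-bit

codes, scale = quantize_8bit(torch.zeros(16))
codes, scale = update_exp_avg(codes, scale, torch.randn(16))
```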

Contributor

Thanks for explaining. That means the FSDP2 logic during fwd/bwd remains the same (doing high-precision gradient reduction) no matter whether the optimizer is in high precision or low precision. The above discussion around DTensor is the core issue then.

@gau-nernst (Collaborator Author) commented Jul 11, 2024

@bdhirsh I followed your suggestion (statically compile the optim step for all params in a single graph) and it worked out great! Everything works now. I needed to separate out the optim state init step (because it seems like torch.compile() can't produce a fullgraph=True graph when there is a Tensor._make_wrapper_subclass() call), but it's a reasonable and simple thing to do. End-to-end training seems to be faster too! Re-running the benchmarks as I'm writing this.

However, there is a small problem: compile time increases. It shouldn't be a big problem for real training, but the FSDP test in CI might time out. The test times out on my local machine; I hope it doesn't in CI. (Even if CI doesn't time out, it is still slightly unfortunate in the sense that it increases CI time.)

UPDATE: indeed the test also times out in CI. I will skip the CPUOffload test.

UPDATE 2: memory usage for 4-bit optimizer explodes :(

| Optimizer | Mem usage before (GB) | Mem usage after (GB) |
| --- | --- | --- |
| FP32 (baseline) | 12.94 | - |
| 8-bit | 8.32 | 8.31 |
| FP8 | 8.32 | 8.32 |
| 4-bit | 7.72 | 12.62 |

@gau-nernst gau-nernst marked this pull request as ready for review July 15, 2024 23:19
@msaroufim msaroufim self-requested a review July 15, 2024 23:27
@msaroufim msaroufim merged commit f7571cf into pytorch:main Jul 16, 2024
13 checks passed
@gau-nernst gau-nernst deleted the low_bit_optim_fsdp2 branch July 17, 2024 05:28
@awgu (Contributor) commented Jul 23, 2024

I did a quick 1000-step run with Adam4bit on Llama3-8B on the c4 dataset on 8x H100s, and Adam4bit seems to be converging well. The memory savings are not that significant though. Looking at the snapshots, it is mainly a memory fragmentation issue.

(memory snapshot screenshots)
torchtitan diff
diff --git a/train.py b/train.py
index b7eee30..3d5a2d0 100644
--- a/train.py
+++ b/train.py
@@ -116,6 +116,11 @@ def build_optimizers(model_parts, job_config: JobConfig):
             optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
         elif name == "AdamW":
             optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
+        elif name == "AdamW4bit":
+            from torchao.prototype import low_bit_optim
+            optimizer_kwargs.pop("fused")
+            optimizer_kwargs.pop("foreach")
+            optimizer = low_bit_optim.AdamW4bit(model.parameters(), **optimizer_kwargs)
         else:
             raise NotImplementedError(f"Optimizer {name} not added.")
Llama3-8B Config
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "compiled_rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW4bit"
lr = 3e-4

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
enable_float8_linear = false
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'none'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy

@gau-nernst (Collaborator Author)

@awgu That looks awesome! Regarding memory fragmentation, do you know how we can improve it? I saw some people using os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" to improve memory allocation.

Also, have you tried AdamW8bit? It seems like 8-bit Adam may hit the recompile cache size limit in multi-GPU FSDP, though there is no problem with a single GPU, so I think we should investigate this issue more. 4-bit Adam doesn't have this issue because right now we do dynamic compile for a single param (instead of static compile for all params) for 4-bit Adam.

@awgu (Contributor) commented Jul 23, 2024

Just for more detail, we see that the segment (white rectangles) allocations are pretty similar for AdamW4bit and normal AdamW, just that fewer are actually filled (i.e. have active blocks) for AdamW4bit:
(memory snapshot screenshots for AdamW4bit and AdamW)

Expandable segments should help. I am hoping PyTorch can come up with a more principled way (I have recently tried to start some discussion around this internally): namely, parameters and optimizer states are persistent through training with fixed/known memory cost, so we can actually preallocate them with 0 fragmentation in theory.

For AdamW8bit, I see an error trace like:

Root Cause (first observed failure):
[0]:
  time      : 2024-07-22_18:59:23
  host      : devgpu011.cco1.facebook.com
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 594488)
  error_file: /tmp/torchelastic_jbbq8n3h/none_r0u1t1aw/attempt_0/6/error.json
  traceback : Traceback (most recent call last):
    File "/data/users/andgu/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
      return f(*args, **kwargs)
    File "/data/users/andgu/torchtrain/train.py", line 427, in main
      optimizers.step()
    File "/data/users/andgu/torchtrain/train.py", line 142, in step
      optimizer.step()
    File "/data/users/andgu/pytorch/torch/optim/lr_scheduler.py", line 136, in wrapper
      return func.__get__(opt, opt.__class__)(*args, **kwargs)
    File "/data/users/andgu/pytorch/torch/optim/optimizer.py", line 487, in wrapper
      out = func(*args, **kwargs)
    File "/data/users/andgu/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torchao-0.3.1-py3.10-linux-x86_64.egg/torchao/prototype/low_bit_optim/adamw.py", line 110, in step
      torch.compile(param_groups_adamw, fullgraph=True)(param_groups)
    File "/data/users/andgu/pytorch/torch/_dynamo/eval_frame.py", line 448, in _fn
      return fn(*args, **kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 1170, in __call__
      return self._torchdynamo_orig_callable(
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 499, in __call__
      return _compile(
    File "/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner
      return func(*args, **kwds)
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 877, in _compile
      raise InternalTorchDynamoError(str(e)).with_traceback(
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 850, in _compile
      guarded_code = compile_inner(code, one_graph, hooks, transform)
    File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 246, in time_wrapper
      r = func(*args, **kwargs)
    File "/data/users/andgu/pytorch/torch/_utils_internal.py", line 85, in wrapper_function
      return StrobelightCompileTimeProfiler.profile_compile_time(
    File "/data/users/andgu/pytorch/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
      return func(*args, **kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 668, in compile_inner
      out_code = transform_code_object(code, transform)
    File "/data/users/andgu/pytorch/torch/_dynamo/bytecode_transformation.py", line 1284, in transform_code_object
      transformations(instructions, code_options)
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 194, in _fn
      return fn(*args, **kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/convert_frame.py", line 610, in transform
      tracer.run()
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 2546, in run
      super().run()
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 910, in run
      while self.step():
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 822, in step
      self.dispatch_table[inst.opcode](self, inst)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
      return inner_fn(self, inst)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 1489, in CALL_FUNCTION
      self.call_function(fn, args, {})
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 760, in call_function
      self.push(fn.call_function(self, args, kwargs))
    File "/data/users/andgu/pytorch/torch/_dynamo/variables/functions.py", line 300, in call_function
      return super().call_function(tx, args, kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/variables/functions.py", line 100, in call_function
      return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in inline_user_function_return
      return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 2761, in inline_call
      return cls.inline_call_(parent, func, args, kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 2877, in inline_call_
      tracer.run()
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 910, in run
      while self.step():
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 822, in step
      self.dispatch_table[inst.opcode](self, inst)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 514, in wrapper
      return inner_fn(self, inst)
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 1489, in CALL_FUNCTION
      self.call_function(fn, args, {})
    File "/data/users/andgu/pytorch/torch/_dynamo/symbolic_convert.py", line 760, in call_function
      self.push(fn.call_function(self, args, kwargs))
    File "/data/users/andgu/pytorch/torch/_dynamo/variables/misc.py", line 745, in call_function
      return self.obj.call_method(tx, self.name, args, kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/variables/tensor.py", line 507, in call_method
      return wrap_fx_proxy(
    File "/data/users/andgu/pytorch/torch/_dynamo/variables/builder.py", line 1863, in wrap_fx_proxy
      return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
    File "/data/users/andgu/pytorch/torch/_dynamo/variables/builder.py", line 1950, in wrap_fx_proxy_cls
      example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
    File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 1848, in get_fake_value
      ret_val = wrap_fake_exception(
    File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 1340, in wrap_fake_exception
      return fn()
    File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 1849, in <lambda>
      lambda: run_node(tx.output, node, args, kwargs, nnmodule)
    File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 1984, in run_node
      raise RuntimeError(make_error_message(e)).with_traceback(
    File "/data/users/andgu/pytorch/torch/_dynamo/utils.py", line 1962, in make_error_message
      return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e)
    File "/data/users/andgu/pytorch/torch/distributed/_tensor/api.py", line 263, in __repr__
      return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
    File "/data/users/andgu/pytorch/torch/_tensor.py", line 994, in __format__
      return object.__format__(self, format_spec)
    File "/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torchao-0.3.1-py3.10-linux-x86_64.egg/torchao/prototype/low_bit_optim/subclass_8bit.py", line 61, in __repr__
      f"{self.__class__.__name__}(signed={self.signed}, block_size={self.block_size}, "
    File "/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torchao-0.3.1-py3.10-linux-x86_64.egg/torchao/prototype/low_bit_optim/subclass_8bit.py", line 39, in block_size
      return self.codes.numel() // self.scale.numel()
  torch._dynamo.exc.InternalTorchDynamoError: integer division or modulo by zero
  
  from user code:
     File "/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torchao-0.3.1-py3.10-linux-x86_64.egg/torchao/prototype/low_bit_optim/adamw.py", line 117, in param_groups_adamw
      single_param_adamw(p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps)
    File "/home/andgu/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torchao-0.3.1-py3.10-linux-x86_64.egg/torchao/prototype/low_bit_optim/adamw.py", line 140, in single_param_adamw
      new_exp_avg = exp_avg.lerp(grad, 1 - beta1)
  
  Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
  
  
  You can suppress this exception and fall back to eager by setting:
      import torch._dynamo
      torch._dynamo.config.suppress_errors = True

@gau-nernst (Collaborator Author)

The AdamW8bit error is indeed strange. I think something fails, then torch.compile tries to print the tensor in make_error_message(), which triggers another error when calling OptimState8bit.__repr__(). It will be hard for me to debug as I don't have access to a multi-GPU cluster.

Anyway, it's probably a good idea to make block_size a fixed attribute (as extra metadata) instead of calculating it on the fly as it is now (the current property is shown below, with a sketch of the alternative after it). I did see some weird errors with torch.compile when doing this in the past.

```python
@property
def block_size(self):
    return self.codes.numel() // self.scale.numel()
```
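
A sketch of the alternative (not the real subclass; the _make_wrapper_subclass machinery is omitted, this only contrasts deriving block_size on every access with storing it once as metadata):

```python
import torch

class OptimState8bitSketch:
    def __init__(self, codes: torch.Tensor, scale: torch.Tensor, block_size: int):
        self.codes = codes
        self.scale = scale
        self.block_size = block_size  # fixed at construction; no numel() // numel() at access time

state = OptimState8bitSketch(torch.zeros(256, dtype=torch.int8), torch.zeros(2), block_size=128)
assert state.block_size == state.codes.numel() // state.scale.numel()
```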

@gau-nernst (Collaborator Author)

Thinking about it again, the memory fragmentation might be partly because I create a normal DTensor and then swap its local tensor with the quantized version. A better way would be to create a quantized local tensor directly and wrap it with DTensor (either by calling DTensor.from_local() or by calling the DTensor constructor directly?).

```python
if isinstance(p, DTensor):
    out = torch.empty_like(p)
    out._local_tensor = self._subclass_zeros(
        out._local_tensor,
        signed,
        self.block_size,
    )
```

@awgu (Contributor) commented Jul 24, 2024

I would need to step through the memory snapshot more systematically, but I am not sure if that is the issue because the non-quantized local tensor should be freed immediately after constructing the quantized local tensor and that memory can be reused for the next parameter's optimizer states (or next optimizer state for a given parameter). As long as you are not constructing local tensors for all parameters at once and then quantizing all of them, this should not be an issue in my understanding.

@gau-nernst (Collaborator Author)

Yup, I won't say it's the only cause of memory fragmentation in this case. The Python garbage collector and the PyTorch memory allocator can work in mysterious ways, I suppose; I'm not an expert in this.

But I think it's a good idea to create the quantized local tensor directly anyway. I did what I did above mainly because I couldn't get DTensor.from_local() to work (it required implementing quite a lot of extra ops). Is calling the DTensor constructor OK? i.e. DTensor(quantized_local_tensor, spec, requires_grad=False). I saw that the constructor arguments changed from 2.3 to 2.4 (device_mesh and placements became a DTensorSpec), so perhaps it's an unstable/private API? I was unsure whether calling the DTensor constructor directly is OK.

@awgu (Contributor) commented Jul 24, 2024

I think calling the DTensor constructor directly is not recommended (@wanchaol does not expect anyone to do this 😃). The from_local() issues might be mitigated now after pytorch/pytorch#130289. With run_check=False, I think you should not need to implement any extra ops 🤔.
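
A rough sketch of that recommended path (assumptions: the mesh/placements are taken from the existing param DTensor, make_quantized_local is a hypothetical helper, and the import path may be torch.distributed._tensor on older PyTorch):

```python
from torch.distributed.tensor import DTensor

def make_dtensor_state(p_dtensor: DTensor, make_quantized_local):
    # build the quantized subclass directly from the local shard, then wrap it,
    # instead of creating a regular DTensor and swapping out _local_tensor afterwards
    local_q = make_quantized_local(p_dtensor.to_local())
    return DTensor.from_local(
        local_q,
        p_dtensor.device_mesh,
        p_dtensor.placements,
        run_check=False,  # avoid cross-rank checks that would need extra ops on the subclass
    )
```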

@gau-nernst (Collaborator Author)

Thanks, I will try!

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
@nighting0le01

@gau-nernst I cannot run low-bit optim with FSDP2:


[rank7]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank7]:   File "<frozen runpy>", line 88, in _run_code
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 226, in <module>
[rank7]:     main()
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 218, in main
[rank7]:     trainer.fit(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank7]:     return trainer_fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank7]:     self._run(model, ckpt_path=ckpt_path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank7]:     results = self._run_stage()
[rank7]:               ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank7]:     self.fit_loop.run()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank7]:     self.advance()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank7]:     self.epoch_loop.run(self._data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank7]:     self.advance(data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 269, in advance
[rank7]:     call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 218, in _call_callback_hooks
[rank7]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 316, in on_train_batch_end
[rank7]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank7]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 715, in _save_none_monitor_checkpoint
[rank7]:     self._save_checkpoint(trainer, filepath)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
[rank7]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1365, in save_checkpoint
[rank7]:     self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/model_parallel.py", line 321, in save_checkpoint
[rank7]:     _distributed_checkpoint_save(converted_state, path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 867, in _distributed_checkpoint_save
[rank7]:     save(converted_state, checkpoint_id=path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 429, in inner_func
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 152, in save
[rank7]:     return _save_state_dict(
[rank7]:            ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 316, in _save_state_dict
[rank7]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 191, in reduce_scatter
[rank7]:     raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 303, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 101, in create_local_plan
[rank7]:     plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 399, in create_default_local_save_plan
[rank7]:     requests += _create_write_items(fqn, obj)
[rank7]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 222, in _create_write_items
[rank7]:     return object.__create_write_items__(fqn, object)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 598, in __create_write_items__
[rank7]:     return [_create_write_items_for_dtensor(fqn, object)]
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 86, in _create_write_items_for_dtensor
[rank7]:     properties=TensorProperties.create_from_tensor(tensor.to_local()),
[rank7]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/metadata.py", line 108, in create_from_tensor
[rank7]:     pin_memory=tensor.is_pinned(),
[rank7]:                ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 377, in _dispatch__torch_function__
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 393, in _dispatch__torch_dispatch__
[rank7]:     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
[rank7]: NotImplementedError: OptimState8bit dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), arg_types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), kwarg_types={}
