Error:
[rank2]: Traceback (most recent call last):
[rank2]: File "/workspace/torchtune/recipes/full_finetune_distributed.py", line 1072, in <module>
[rank2]: sys.exit(recipe_main())
[rank2]: ^^^^^^^^^^^^^
[rank2]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]: sys.exit(recipe_main(conf))
[rank2]: ^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/torchtune/recipes/full_finetune_distributed.py", line 1067, in recipe_main
[rank2]: recipe.train()
[rank2]: File "/workspace/torchtune/recipes/full_finetune_distributed.py", line 916, in train
[rank2]: current_loss.backward()
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
[rank2]: torch.autograd.backward(
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 354, in backward
[rank2]: _engine_run_backward(
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank2]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
[rank2]: return user_fn(self, *args)
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2129, in backward
[rank2]: all_args = _backward_prologue_functional(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1662, in _backward_prologue_functional
[rank2]: flat_processed_tangents = list(
[rank2]: ^^^^^
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1665, in <genexpr>
[rank2]: AOTDispatchAutograd.process_runtime_tangent(
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1902, in process_runtime_tangent
[rank2]: new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1873, in process_runtime_tangent
[rank2]: raise RuntimeError(
[rank2]: RuntimeError:
[rank2]: During the backward, we encountered a tensor subclass where we guessed its
[rank2]: metadata incorrectly.
[rank2]: Expected metadata: {'_orig_dtype': torch.bfloat16, '_linear_mm_config': LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), '_gemm_input_role': <GemmInputRole.INPUT: 'input'>, '_axiswise_dim': None}, expected type: <class 'torchao.float8.float8_tensor.Float8Tensor'>
[rank2]: Runtime metadata: None, runtime type: <class 'torch.Tensor'>
[rank2]: shape: torch.Size([1, 1024, 4096])
[rank2]: To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
The error was hit on a custom branch incorporating other fixes to TP (it should be reproducible on main), with torchao from pytorch/ao#2154. With these branches, TP + FP8 works, but enabling compile on top does not.
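
For reference, below is a minimal standalone sketch of the failing combination (FP8 training via torchao, tensor parallelism, and torch.compile). This is an assumption-based distillation, not the recipe's actual code: the MLP shape, the ColwiseParallel/RowwiseParallel plan, and the conversion order are placeholders, and the real recipe may use torchao's Float8-aware TP styles instead.

# torchrun --nproc_per_node 2 repro.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torchao.float8 import convert_to_float8_training


class MLP(nn.Module):
    def __init__(self, dim: int = 4096, hidden: int = 14336):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

model = MLP().to(device="cuda", dtype=torch.bfloat16)

# Swap nn.Linear modules for Float8Linear; their activations are
# Float8Tensor subclasses, which is the type AOTAutograd guesses for
# the tangent in the error above.
convert_to_float8_training(model)

# Shard the two projections across the TP mesh (plan is a placeholder).
parallelize_module(model, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

model = torch.compile(model)

x = torch.randn(1, 1024, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = model(x)
# The failure fires here: at trace time AOTAutograd guessed a Float8Tensor
# tangent, but the runtime tangent arrives as a plain torch.Tensor.
out.sum().backward()

Per the report above, the same stack without torch.compile runs fine; the mismatch only appears once all three pieces are combined.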