Representing DTensor in thunder traces #1907

Merged: 62 commits merged into Lightning-AI:main on Jun 13, 2025

Conversation

@kshitij12345 (Collaborator) commented Mar 26, 2025

Fixes #1898 (Accept DTensor input without errors)

Design Doc - https://docs.google.com/document/d/1Gqb_jXrL-sSqs-D8KrZdcQinxuUSlccZBnnvbYJfYl0/edit?usp=sharing

Changes -
This PR adds support for DTensor inputs to jitted functions. Most of the additions required to support DTensor live under thunder/torch/experimental, e.g. the DTensorProxy and the related prims.

NOTE:

  • This PR only adds the basic infrastructure needed to run a simple DTensor program (with torch.mul and no broadcasting). Operator coverage will follow in subsequent PRs.
  • The thunderfx path currently fails (we add a test asserting that); it will be fixed in a separate PR.

Following are the main updates:

  1. Prologue: Adds a new primitive check_dtensor_spec_repr which matches the repr of the DTensorSpec of the DTensor in question (see the example below; a rough sketch of such a check follows the prologue trace). The PR also makes sure that, besides the DTensorSpec, there is a tensor metadata check for the DTensor object as well as for the local tensor it points to. NOTE - another option for checking the DTensorSpec would be to keep the input's DTensorSpec in the TracingContext and have the prologue verify equality against it.

  2. DTensorProxy: Adds a new Proxy object to represent the DTensor. This class inherits from TensorProxy, since DTensor is a tensor subclass, and implements all the methods that a tensor implements.

  3. Prims and Operations: For the computation trace, we add prims and torch-level operations for DTensor. We add new prims and operations instead of re-using the existing ones so that executors cannot mistakenly claim an operation on a DTensor.

  4. Representation in trace -

Example Program

from torch.distributed.tensor import DTensor
from torch.distributed import init_device_mesh
import torch
import os

os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
mesh = init_device_mesh("cuda", (1,), mesh_dim_names=["i"])

x_dtensor = DTensor.from_local(torch.randn(2, 2), device_mesh=mesh)
w_dtensor = DTensor.from_local(torch.randn(2, 2), device_mesh=mesh)

import thunder

@thunder.jit
def fn(x, w):
    return x * w

fn(x_dtensor, w_dtensor)

Prologue Trace (relevant snippet)

# print(fn._lc_cs.last_prologue_traces[-1])
@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  prims.check_len(args, 2)
  # kwargs: "Any"
  prims.check_len(kwargs, 0)
  l_x_: "DTensor cuda:0 f32[16, 16]" = args[0]
  l_w_: "DTensor cuda:0 f32[16, 16]" = args[1]
  dtensor_spec0: "<class 'NoneType'>" = l_x_._spec
  thunder.torch.experimental.dtensor_prims_and_impl.check_dtensor_spec_repr(dtensor_spec0, "DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([16, 16]), stride=(16, 1), dtype=torch.float32))")
  t1: "cuda:0 f32[8, 16]" = l_x_._local_tensor
  prims.check_tensor_shape_and_metadata(t1, (8, 16), 'cuda:0', torch.float32, True)
  prims.check_tensor_shape_and_metadata(l_x_, (16, 16), 'cuda:0', torch.float32, True)
  dtensor_spec2: "<class 'NoneType'>" = l_w_._spec
  thunder.torch.experimental.dtensor_prims_and_impl.check_dtensor_spec_repr(dtensor_spec2, "DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([16, 16]), stride=(16, 1), dtype=torch.float32))")
  t3: "cuda:0 f32[8, 16]" = l_w_._local_tensor
  prims.check_tensor_shape_and_metadata(t3, (8, 16), 'cuda:0', torch.float32, False)
  prims.check_tensor_shape_and_metadata(l_w_, (16, 16), 'cuda:0', torch.float32, False)

Computation Trace: there is a torch-level symbol dtensor_mul which decomposes into prims for DTensor operations.

# print(fn._lc_cs.last_traces[0])
@torch.no_grad()
@no_autocast
def computation(x, w):
  # x: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  # w: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"

  # /opt/pytorch/lightning-thunder/test_dtensor.py:21: 	    return torch.mul(x, w)
  dtensor_6 = thunder.torch.experimental.dtensor_torch_and_prims.dtensor_mul(x, w)  # dtensor_6: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
    # dtensor_6 = thunder.torch.experimental.dtensor_torch_and_prims.dtensor_mul_prim(x, w)  # dtensor_6: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  return (dtensor_6,)
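As a quick sanity check of the computation above (an illustration, not part of the PR; it assumes the single-rank, replicate-by-default setup from the example program), the jitted result should be a DTensor whose local value matches eager torch.mul on the local tensors:

# Sanity check for the example program above (illustrative only).
result = fn(x_dtensor, w_dtensor)
assert isinstance(result, DTensor)
torch.testing.assert_close(result.to_local(), x_dtensor.to_local() * w_dtensor.to_local())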

Backward Trace (initial trace)

# print(fn._lc_cs.last_backward_traces[0])
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  # C0: "Collection"
  # None
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  dtensor_0, = cotangents
  # dtensor_0: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  clear_mutable_collection(cotangents)
  del cotangents
  w, x, = C0
  # w: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  # x: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  clear_mutable_collection(C0)
  del C0
  bw_dtensor_19 = dtensor_mul_prim(w, dtensor_0)  # bw_dtensor_19: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
    # bw_dtensor_19 = thunder.torch.experimental.dtensor_torch_and_prims.dtensor_mul_prim(w, dtensor_0)  # bw_dtensor_19: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  del w
  bw_dtensor_22 = dtensor_mul_prim(x, dtensor_0)  # bw_dtensor_22: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
    # bw_dtensor_22 = thunder.torch.experimental.dtensor_torch_and_prims.dtensor_mul_prim(x, dtensor_0)  # bw_dtensor_22: "DTensor cuda:0 f32[16, 16] mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)"
  del x, dtensor_0
  return (bw_dtensor_19, bw_dtensor_22)
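For reference, the two dtensor_mul_prim calls above are the usual product rule for z = x * w: the gradient with respect to x is g * w and the gradient with respect to w is g * x, where g is the incoming cotangent. A small, self-contained check with plain tensors (illustration only, not part of the PR):

import torch

x = torch.randn(2, 2, requires_grad=True)
w = torch.randn(2, 2, requires_grad=True)
g = torch.randn(2, 2)  # cotangent

z = x * w
grad_x, grad_w = torch.autograd.grad(z, (x, w), grad_outputs=g)
torch.testing.assert_close(grad_x, g * w.detach())
torch.testing.assert_close(grad_w, g * x.detach())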

Thank you Masaki, Ivan and Mike for the helpful discussions and guidance!

@IvanYashchuk added the DTensor label ("Issues about DTensor support in Thunder") on Apr 2, 2025
@kshitij12345 (Collaborator, Author) commented:

Gentle ping @IvanYashchuk

@kshitij12345 (Collaborator, Author) commented:

With the latest merge, I am seeing a failure in the test for the ConstantFolding transform. The error seems legitimate, and I am not sure why the test used to pass before (the failure should have happened previously as well).

Cause -
This happens because ConstantFolding relies on the internal map _torch_to_thunder_function_map (from torch_fn to Symbol), which it inverts:

_thunder_to_torch_function_map = {v: k for k, v in _torch_to_thunder_function_map.items()}

But this PR updates the map so that torch_fn now maps to a Callable that dispatches to the correct Symbol:

def register_function_for_dtensor(torch_fn, single_device_symbol, dtensor_symbol, is_method=False):
    register_function(torch_fn, dispatch_to_impl(single_device_symbol, dtensor_symbol))

    if is_method:
        method_name: str = torch_fn.__name__
        torch_method: None | Callable = getattr(torch.Tensor, method_name, None)
        register_method_for_dtensor(torch_method, single_device_symbol, dtensor_symbol)
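To make the breakage concrete, here is a small, self-contained illustration with stand-in objects (not thunder's actual map or Symbol type) of why the inversion stops working once map values become dispatch wrappers:

def torch_mul(a, b):   # stands in for torch.mul
    return a * b

def mul_symbol(a, b):  # stands in for the single-device thunder Symbol
    return a * b

def dispatch_to_impl(single_device_symbol, dtensor_symbol):
    def wrapper(*args, **kwargs):
        return single_device_symbol(*args, **kwargs)  # type-based dispatch elided
    return wrapper

# Before this PR: torch_fn -> Symbol, so inverting the map recovers torch_fn from the Symbol.
fn_map = {torch_mul: mul_symbol}
inverse = {v: k for k, v in fn_map.items()}
assert mul_symbol in inverse

# After this PR: torch_fn -> dispatch wrapper, so lookups keyed by the Symbol fail.
fn_map = {torch_mul: dispatch_to_impl(mul_symbol, None)}
inverse = {v: k for k, v in fn_map.items()}
assert mul_symbol not in inverse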

Workaround -
I have a fix in mind so that ConstantFolding does not rely on this map. So I think we should xfail the test in this PR, and I will send a follow-up PR to get the ConstantFolding test passing again.

cc: @IvanYashchuk

@kshitij12345 (Collaborator, Author) commented:

Ping @t-vi for stamp

@t-vi (Collaborator) left a comment.

@t-vi enabled auto-merge (squash) on June 13, 2025 at 15:15
@t-vi merged commit d665072 into Lightning-AI:main on Jun 13, 2025
49 checks passed
@t-vi (Collaborator) commented Jun 13, 2025:

@kshitij12345 will you file an issue about USE_DISTRIBUTED=OFF?

@kshitij12345 (Collaborator, Author) commented Jun 13, 2025:

Opened #2233 to track making Thunder work with PyTorch compiled without distributed support.
