Representing DTensor in thunder traces #1907

Open: wants to merge 41 commits into main
Commits (41)
5873742
dtensor support
kshitij12345 Mar 21, 2025
377125a
add comment
kshitij12345 Mar 24, 2025
7ab82f6
add more comments
kshitij12345 Mar 24, 2025
e6aa8d3
update comment
kshitij12345 Mar 24, 2025
e76fc17
add test for expected failing cases
kshitij12345 Mar 24, 2025
eaac9f7
support for method
kshitij12345 Mar 24, 2025
94ef69d
update failing case test
kshitij12345 Mar 24, 2025
5d81851
remove generated traces
kshitij12345 Mar 26, 2025
7277753
undo pre-commit change
kshitij12345 Mar 26, 2025
a8c58e4
undo debug changes
kshitij12345 Mar 26, 2025
d87b103
update failing test to use thunder.jit
kshitij12345 Mar 26, 2025
b101161
update registration helper
kshitij12345 Mar 26, 2025
b551cb8
Apply suggestions from code review
kshitij12345 Mar 31, 2025
1c75a80
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 1, 2025
5854c86
address review and update
kshitij12345 Apr 1, 2025
a778830
update dtensor proxy repr
kshitij12345 Apr 2, 2025
41990d0
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 2, 2025
eda0277
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 3, 2025
8abf040
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 9, 2025
225f2e3
update jit_ext access to torchfn_to_thunder registry : test
kshitij12345 Apr 9, 2025
2b85b31
empty commit
kshitij12345 Apr 9, 2025
5d0296f
Revert "update jit_ext access to torchfn_to_thunder registry : test"
kshitij12345 Apr 10, 2025
dedab03
temp commit
kshitij12345 Apr 15, 2025
efaae1d
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 24, 2025
ddcf208
update to manual decomp
kshitij12345 Apr 24, 2025
6a6bf11
add manual grad rule
kshitij12345 Apr 24, 2025
2a8ea02
update
kshitij12345 May 14, 2025
9490cac
update - clean-up
kshitij12345 May 14, 2025
bd1ecbb
update attrs on DTensorProxy
kshitij12345 May 14, 2025
5d1f20b
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 May 14, 2025
255c82d
remove debug change
kshitij12345 May 15, 2025
dba02d2
remove unused imports
kshitij12345 May 15, 2025
e8f6d0b
remove unused import
kshitij12345 May 15, 2025
83ae80a
update function name
kshitij12345 May 15, 2025
4cd9bec
cotangent metadata check initial support
kshitij12345 May 16, 2025
b49df80
address review : p1
kshitij12345 May 16, 2025
b206632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2025
b70238a
address review
kshitij12345 May 16, 2025
1ff76b3
Merge branch 'dtensor-init-support' of https://github.com/kshitij1234…
kshitij12345 May 16, 2025
3e295da
update and refactor
kshitij12345 May 16, 2025
8f4a029
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2025
8 changes: 7 additions & 1 deletion thunder/__init__.py
@@ -413,7 +413,13 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
# Use `type(t) is pytorch.Tensor` because tensor subclasses don't support calling
# data_ptr(),
# e.g. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass)
#
# isinstance(t, pytorch.Tensor) or pytorch.is_tensor(t) would match all Tensor objects including
# subclasses.
if type(t) is pytorch.Tensor and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
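
For context on the exact-type check above, here is a minimal standalone sketch (not part of the PR) showing why `type(t) is pytorch.Tensor` is used instead of `isinstance`/`torch.is_tensor`: both of the latter also match tensor subclasses such as DTensor, which this alias-detection path is meant to skip.

```python
import torch

# Hypothetical stand-in for a tensor subclass like DTensor.
class MySubclassTensor(torch.Tensor):
    pass

t = torch.randn(2).as_subclass(MySubclassTensor)

print(isinstance(t, torch.Tensor))  # True  - isinstance matches subclasses
print(torch.is_tensor(t))           # True  - so does torch.is_tensor
print(type(t) is torch.Tensor)      # False - the exact-type check skips them
```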
10 changes: 10 additions & 0 deletions thunder/core/codeutils.py
@@ -152,6 +152,11 @@ def to_printable(
if isinstance(x, ProxyInterface):
return x

from thunder.torch.experimental.dtensor_codeutils import populate_object_ctx_for_dtensor_spec

if populate_object_ctx_for_dtensor_spec(x, object_ctx):
return x

if dataclasses.is_dataclass(x):
# Add `class` to the object_ctx so that we can reuse it during the trace execution.
if isinstance(x, type): # dataclass type
@@ -236,6 +241,11 @@ def prettyprint(
if isinstance(x, ContextObject):
return m(x.name)

from thunder.torch.experimental.dtensor_codeutils import prettyprint_dtensor_spec

if (dtensor_repr := prettyprint_dtensor_spec(x)) != "":
return m(dtensor_repr)

if dataclasses.is_dataclass(x):
# For a dataclass instance of class
# class MyContainer:
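
To make the mechanism above concrete: `populate_object_ctx_for_dtensor_spec` collects the names that the printed trace source will need when it is later executed, and `prettyprint_dtensor_spec` emits a `repr()` string that references those names. A toy sketch of that pattern (my simplification, not thunder's actual trace machinery):

```python
# Toy sketch (simplified assumption): a printed thunder trace is Python source
# that is executed with the collected object_ctx providing the global names, so
# any name the printed repr references (DTensorSpec, DeviceMesh, Shard, ...)
# must first be registered in object_ctx.
object_ctx = {"Spec": dict}                    # stand-in for DTensorSpec & friends
trace_source = "spec = Spec(placements=[0])"   # stand-in for what prettyprint emits
namespace = dict(object_ctx)
exec(trace_source, namespace)
assert namespace["spec"] == {"placements": [0]}
```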
16 changes: 16 additions & 0 deletions thunder/core/jit_ext.py
@@ -74,6 +74,10 @@
from thunder.torch import _torch_to_thunder_function_map
from thunder.clang import _clang_fn_set
from thunder.core.pytree import tree_map, tree_iter
from thunder.torch.experimental.dtensor_torch_and_prims import register_dtensor_torch_and_prims

# TODO: Find a better place to register these ops (mostly in thunder/torch/__init__.py but without cyclical dependency).
register_dtensor_torch_and_prims()
Comment on lines +77 to +80
Collaborator:
nit-picking: Potential necessity of a torch.distributed check.
Probably I just don't understand the implementation correctly, but I'm not convinced of this registration's "safety" when torch.distributed is not available, because IIUC this registration really uses DTensor and runs some "meta" computation.

Though in most cases the submodule itself should be available, so I'm not that strongly concerned.

Collaborator (Author):

I think this is true for other files as well. Will have a look later, thanks!
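
A sketch of the kind of guard the reviewer is suggesting, mirroring the `torch.distributed.is_available()` check this PR already adds in `thunder/dynamo/utils.py` (illustrative placement only, not the PR's actual code):

```python
import torch

# Only register the DTensor ops when torch.distributed is actually available.
if torch.distributed.is_available():
    from thunder.torch.experimental.dtensor_torch_and_prims import (
        register_dtensor_torch_and_prims,
    )

    register_dtensor_torch_and_prims()
```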


#
# jit_ext.py implements extensions of thunder's interpreter
@@ -273,9 +277,15 @@ def proxify(self, value: WrappedValue) -> Any:

if p is not uvalue:
value.register_proxy(p)

from thunder.torch.experimental.dtensor_proxy import is_dtensor_proxy
from thunder.torch.experimental import dtensor_torch_and_prims

# TODO: other caching modes
co: CACHE_OPTIONS = get_cache_option()
if co is CACHE_OPTIONS.CONSTANT_VALUES:
if is_dtensor_proxy(p):
self.add_constraint((dtensor_torch_and_prims.check_dtensor_spec_repr, p, uvalue._spec))
self.add_constraint((clang.check_tensor_shape_and_metadata, p))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
# TODO: establish guarding logic to allow non-broadcast shape change
@@ -1840,6 +1850,12 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:
if isinstance(s, Proxy):
unpack(s)

# Add checks for the local tensor, mesh, and placement of a DTensor
from thunder.torch.experimental.dtensor_torch_and_prims import handle_check_dtensor_spec_in_prologue

if handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args):
continue

prim(*args)

cache_info = thunder._get_cache_info()
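
`check_dtensor_spec_repr` is defined in `thunder/torch/experimental/dtensor_torch_and_prims.py` elsewhere in this PR; its implementation is not shown here. A hedged sketch of what a repr-based prologue guard could look like (assumed shape, hypothetical name):

```python
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec


def check_dtensor_spec_repr_sketch(t: DTensor, trace_time_spec: DTensorSpec) -> bool:
    # Guard the CONSTANT_VALUES cache on the DTensor's sharding spec: the repr of
    # a DTensorSpec covers mesh, placements and tensor metadata, so a change in
    # any of them invalidates the cached computation trace.
    return repr(t._spec) == repr(trace_time_spec)
```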
3 changes: 2 additions & 1 deletion thunder/core/prims.py
@@ -1834,7 +1834,8 @@ def _get_grad_meta(a: Number | NumberProxy | TensorProxy, /) -> Number | TensorP
utils.check_type(a, (Number, NumberProxy, TensorProxy))

if isinstance(a, TensorProxy):
return TensorProxy(like=a)
# NOTE: `a` could be a TensorProxy subclass and its type should be preserved.
return type(a)(like=a)

# NOTE a is a Number in this branch
return numberproxy(pytype(a), 0)
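
The `type(a)(like=a)` pattern keeps the proxy's concrete class, so a `DTensorProxy` input produces a `DTensorProxy` grad placeholder rather than a plain `TensorProxy`. A toy sketch of the pattern with stand-in classes (not thunder's real proxies):

```python
class TensorProxyToy:
    def __init__(self, *, like=None):
        self.shape = like.shape if like is not None else (2, 2)


class DTensorProxyToy(TensorProxyToy):
    pass


a = DTensorProxyToy()
grad_placeholder = type(a)(like=a)       # constructs a DTensorProxyToy
assert type(grad_placeholder) is DTensorProxyToy

old_behavior = TensorProxyToy(like=a)    # what `TensorProxy(like=a)` amounted to
assert type(old_behavior) is TensorProxyToy
```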
6 changes: 6 additions & 0 deletions thunder/core/proxies.py
@@ -2059,6 +2059,12 @@ def proxy(x: Any, *, name: str | None = None, history: None | tuple = None) -> A
if x is ...:
return AnyProxy(x, name=name, history=history)

# Import here to avoid cyclical dependency.
from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

if (dtensor_proxy := proxify_dtensor(x, name, history)) is not None:
return dtensor_proxy

if isinstance(x, torch.Tensor):
return tensorproxy(x, name=name, history=history)

3 changes: 2 additions & 1 deletion thunder/core/transforms.py
@@ -3130,7 +3130,8 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa

def ones_like(x):
if isinstance(x, TensorProxy):
return full_like(x, fill_value=1)
# NOTE: x could be a subclass of TensorProxy and that should be preserved.
return type(x)(like=x)
elif isinstance(x, NumberProxy):
return type(x.value)(1)
else:
9 changes: 8 additions & 1 deletion thunder/dynamo/utils.py
@@ -6,12 +6,18 @@
import inspect
import itertools
import copy
from types import NoneType
from collections import defaultdict

import torch
from torch.nn.modules.module import _addindent
from torch._subclasses.fake_tensor import FakeTensor

if torch.distributed.is_available():
from torch.distributed.tensor import DTensor
else:
DTensor = NoneType

from thunder.torch.default_torch_ops import torch_auto_registered_ops
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.langctx import torchctx
@@ -507,7 +513,8 @@ def _get_storage_shape(t: torch.Tensor):


def _get_min_and_val(t: torch.Tensor) -> tuple[Number | None, Number | None]:
if isinstance(t, FakeTensor) or t.device.type == "meta" or t.numel() == 0:
# We assume that `aminmax` is not supported for tensor subclasses, which is true for FakeTensor and DTensor.
if (isinstance(t, torch.Tensor) and type(t) is not torch.Tensor) or t.device.type == "meta" or t.numel() == 0:
return None, None
if t.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
t = t.to(torch.float32)
6 changes: 6 additions & 0 deletions thunder/executors/torch_autograd.py
@@ -437,6 +437,12 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
# We only want the forward function to be called with `te.fp8_autocast` manager.
bw_extrace._include_te_fp8_autocast = False

# Should this be a post-optimization transform?
# We only want to apply it to the backward trace.
from thunder.torch.experimental.dtensor_utils import check_dtensor_cotangent_metadata_in_backward

bw_extrace = check_dtensor_cotangent_metadata_in_backward(bw_extrace)

if len(bw_extrace.bound_symbols) == 1:
# only return, no unpacking, so no gradient is calculated
bw_extrace = None
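
`check_dtensor_cotangent_metadata_in_backward` is added by this PR in `thunder/torch/experimental/dtensor_utils.py`; per the new test below, a cotangent whose placements differ from the ones seen at trace time fails at runtime. A hedged sketch of the runtime check it implies (hypothetical helper, not the PR's actual code):

```python
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec


def check_cotangent_metadata_sketch(trace_time_spec: DTensorSpec, cotangent: DTensor) -> None:
    # The backward trace was generated for a cotangent with one particular
    # sharding spec; silently accepting a differently-placed cotangent could
    # produce wrong gradients, so fail loudly instead.
    if repr(cotangent._spec) != repr(trace_time_spec):
        raise RuntimeError(
            "DTensorSpec has changed for cotangent between tracing and runtime: "
            f"expected {trace_time_spec}, found {cotangent._spec}"
        )
```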
115 changes: 115 additions & 0 deletions thunder/tests/distributed/test_dtensor.py
@@ -0,0 +1,115 @@
import unittest

import pytest
import torch

if not torch.distributed.is_available():
pytest.skip(allow_module_level=True)

from thunder.dynamo import thunderfx
import thunder

from thunder.tests.distributed.helper import DistributedParallelTestCase
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed.tensor.placement_types import Placement, Shard, Replicate


@unittest.skipUnless(
torch.cuda.is_available() and torch.distributed.is_nccl_available(),
"DTensor test requires CUDA and NCCL `torch.distributed` backend",
)
class DTensorTest(DistributedParallelTestCase):
def test_dtensor_basic_op(self):
num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))

dim_size = 16

def _helper(fn, in_dtensor, w_dtensor):
expected = torch.compile(fn)(in_dtensor, w_dtensor)
tmodel = thunder.jit(fn)
actual = tmodel(in_dtensor, w_dtensor)

torch.testing.assert_close(actual, expected)

g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)])
expected_g = torch.autograd.grad(
expected,
(in_dtensor, w_dtensor),
g_o,
)
actual_g = torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)

torch.testing.assert_close(actual_g, expected_g)

w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

# Verify torch API works
_helper(lambda x, w: torch.mul(x, w), in_dtensor, w_dtensor)

# Verify calling method works
_helper(lambda x, w: torch.Tensor.mul(x, w), in_dtensor, w_dtensor)

# Verify calling the special method works
_helper(lambda x, w: x * w, in_dtensor, w_dtensor)

def test_dtensor_unsupported(self):
num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))

dim_size = 16

w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

def fn(x, w):
return torch.div(x, w)

tmodel = thunder.jit(fn)
with pytest.raises(AssertionError):
tmodel(in_dtensor, w_dtensor)

def fn(x, w):
return x / w

tmodel = thunder.jit(fn)
with pytest.raises(AssertionError):
tmodel(in_dtensor, w_dtensor)

def test_dtensor_unsupported_mixed_input(self):
num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))

dim_size = 16

def fn(x, w):
return torch.div(x, w)

w = torch.randn(dim_size, dim_size, requires_grad=True)

in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

tmodel = thunder.jit(fn)
with pytest.raises(AssertionError):
tmodel(in_dtensor, w)

def test_dtensor_incorrect_cotangent(self):
num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))

dim_size = 16

w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

def fn(x, w):
return torch.mul(x, w)

tmodel = thunder.jit(fn)
actual = tmodel(in_dtensor, w_dtensor)
g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(1)])

with pytest.raises(RuntimeError, match="has changed for cotangent between tracing and runtime"):
torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)
26 changes: 24 additions & 2 deletions thunder/torch/__init__.py
@@ -145,6 +145,7 @@ def __init__(
is_prim: bool = False,
tags: None | list[Any] = None,
out_of_place: Symbol | None = None,
allow_tensor_subclass_proxy: bool = False,
):
self.torchfns = torchfns
self.is_method = is_method or (method_name is not None)
@@ -157,9 +158,30 @@ def __init__(
self.tags = tags
self.out_of_place = out_of_place

# This flag controls whether a torchsymbol accepts
# TensorProxy subclasses as input (e.g. DTensorProxy).
# By default it is `False`, because a general `torchsymbol`
# written for TensorProxy should not silently accept DTensorProxy.
self.allow_tensor_subclass_proxy = allow_tensor_subclass_proxy

def __call__(self, fn: Callable) -> Symbol:
_fn = langctx(Languages.TORCH)(fn)

if not self.allow_tensor_subclass_proxy:

@wraps(_fn)
def wrapper(*args, **kwargs):
filter_tensor_proxies = list(
filter(lambda t: isinstance(t, TensorProxy), tree_flatten((args, kwargs))[0])
)
assert all(
map(lambda t: type(t) is TensorProxy, filter_tensor_proxies)
), f"Expected all inputs to be TensorProxy but found {list(map(lambda t: type(t), filter_tensor_proxies))}"
return _fn(*args, **kwargs)

else:
wrapper = _fn

id: str
if self.id is None:
name = fn.__name__
@@ -184,10 +206,10 @@ def __call__(self, fn: Callable) -> Symbol:

if self.is_prim:
sym = Symbol(
name=fn.__name__, meta=langctx(Languages.PRIMS)(_fn), id=id, is_prim=self.is_prim, tags=self.tags
name=fn.__name__, meta=langctx(Languages.PRIMS)(wrapper), id=id, is_prim=self.is_prim, tags=self.tags
)
else:
sym = Symbol(name=fn.__name__, meta=_fn, id=id, is_prim=self.is_prim, tags=self.tags)
sym = Symbol(name=fn.__name__, meta=wrapper, id=id, is_prim=self.is_prim, tags=self.tags)

if self.is_method:
method_name: str = self.method_name if self.method_name is not None else fn.__name__
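
A standalone restatement of what the wrapper added above enforces (toy classes, not thunder's real proxies), useful for seeing when the assertion fires:

```python
class TensorProxy: ...                 # stand-in for thunder.core.proxies.TensorProxy
class DTensorProxy(TensorProxy): ...   # stand-in for the new DTensorProxy


def reject_subclass_proxies(*args):
    # Mirrors the wrapper's check when allow_tensor_subclass_proxy=False (the
    # default): every TensorProxy-typed argument must be exactly a TensorProxy.
    tensor_proxies = [t for t in args if isinstance(t, TensorProxy)]
    assert all(type(t) is TensorProxy for t in tensor_proxies), (
        f"Expected all inputs to be TensorProxy but found {[type(t) for t in tensor_proxies]}"
    )


reject_subclass_proxies(TensorProxy(), 3.0)   # passes
try:
    reject_subclass_proxies(DTensorProxy())   # raises: DTensorProxy is a subclass
except AssertionError as e:
    print(e)
```

Symbols that are written to handle DTensor inputs opt out of this check by passing `allow_tensor_subclass_proxy=True` to `torchsymbol`.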
31 changes: 31 additions & 0 deletions thunder/torch/experimental/dtensor_codeutils.py
@@ -0,0 +1,31 @@
from typing import Any
from torch.distributed.tensor._dtensor_spec import DTensorSpec, DeviceMesh, TensorMeta
from torch.distributed.tensor import DeviceMesh, Partial, Placement, Replicate, Shard


def populate_object_ctx_for_dtensor_spec(x: Any, object_ctx: dict[str, Any]) -> bool:
"""
Populate object context for DTensorSpec.

.. note::
This function mutates `object_ctx`.

Returns:
bool: True if `x` is DTensorSpec (and also updates `object_ctx`) otherwise False.
"""
if isinstance(x, DTensorSpec):
object_ctx["DTensorSpec"] = DTensorSpec
object_ctx["DeviceMesh"] = DeviceMesh
object_ctx["Placement"] = Placement
object_ctx["Replicate"] = Replicate
object_ctx["Shard"] = Shard
object_ctx["Partial"] = Partial
object_ctx["TensorMeta"] = TensorMeta
return True
return False


def prettyprint_dtensor_spec(x):
if isinstance(x, DTensorSpec):
return x.__repr__()
return ""