Trace Transform for Tensor Wrapper Subclasses #1883

Draft · wants to merge 38 commits into base: main
Commits (38)
398ae2a
Proxy and prims
crcrpar Dec 23, 2024
0bf415b
remove unnecessary history
crcrpar Jan 29, 2025
3b7c0fc
prim should live with not great printing of class
crcrpar Jan 29, 2025
9a07d68
core aten ops
crcrpar Dec 25, 2024
00e9d01
trace transform of tensor wrapper subclass
crcrpar Dec 23, 2024
6591ef6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
2cd32de
no pretty prints for flatten and unflatten
crcrpar Jan 29, 2025
c5b7fe6
trace transform of tensor wrapper subclass
crcrpar Dec 23, 2024
8bcb2fd
updates for MLP with `torchao.float8`
crcrpar Dec 25, 2024
b506502
check tensor attrs of tensor wrapper subclasses in prologue
crcrpar Dec 28, 2024
c6e58ff
print type_string of tensor attributes
crcrpar Dec 29, 2024
beb2f7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 29, 2024
5d83cf1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
348c531
update stride conversion
crcrpar Feb 19, 2025
c170989
activate striedes checks of scaled_mm inputs
crcrpar Feb 20, 2025
97e03b0
post init refactoring
crcrpar Feb 27, 2025
1aa342d
refactor dunder call
crcrpar Feb 27, 2025
d20f4e9
cast `torch.(dtype|device)` to thunder.(devices|dtypes)`
crcrpar Feb 27, 2025
3de2ce2
flatten backward traces
crcrpar Feb 27, 2025
40689ef
add core aten version of abs
crcrpar Feb 27, 2025
498b4f1
allow `memory_format` in clone
crcrpar Feb 28, 2025
c8d6961
maintain map from proxy to fake tensor strides
crcrpar Mar 1, 2025
2a0abdb
call backward
crcrpar Mar 1, 2025
04c51ca
use strides info of `saved_for_backward`
crcrpar Mar 2, 2025
962d4a8
`AdHocExecutor` -> `TemporaryExecutor`
crcrpar Mar 2, 2025
b9657dc
After 1500 my work would never work.
crcrpar Mar 2, 2025
abde3e3
bypass evaluation of bsym if executor is torchex or temporary
crcrpar Mar 3, 2025
b9c15c0
`ad_hoc_executor` -> `temporary_executor`
crcrpar Mar 3, 2025
c05e916
improve `shallow_copy` for tensor subclass
crcrpar Mar 3, 2025
7d05d76
allow name duplicates for `SubclassTensorProxy`
crcrpar Mar 6, 2025
8ec75c4
convert Enum and dataclass to proxy
crcrpar Mar 6, 2025
224fd34
specialize checker for clone
crcrpar Mar 6, 2025
e62234e
remove creating clone of subclasstensorproxy
crcrpar Mar 6, 2025
4b10385
bind_postprocess for `torchex.unflatten_tensor_subclass`
crcrpar Mar 7, 2025
bdd4374
Do `updated_bsym.sym(*updated_bsym.args, **updated_bsym.kwargs)`
crcrpar Mar 9, 2025
7c02614
allow memory_format arg
crcrpar Mar 10, 2025
186a298
update tests
crcrpar Mar 10, 2025
5283902
use gelu w/o tanh approximate as it makes traces shorter
crcrpar Mar 14, 2025
1 change: 1 addition & 0 deletions docs/source/reference/transforms/index.rst
@@ -8,3 +8,4 @@ thunder.transforms

MaterializationTransform
ConstantFolding
unroll_tensor_subclasses
10 changes: 8 additions & 2 deletions thunder/__init__.py
@@ -73,6 +73,7 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, connect_to_autograd
from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses

# NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
import torch as pytorch
@@ -372,7 +373,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
@@ -479,7 +480,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com

prologue_traces += transform_for_execution(
prologue_trc,
executors_list=(pythonex,),
executors_list=(pythonex, get_executor("torch")),
Review comment (Collaborator Author):
To be strict and precise, I want to include torchex only when the prologue takes SubclassTensorProxys, because torchex is only here for the tensor_subclass.__tensor_flatten__() calls used inside the prologue.
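
A minimal sketch of what this comment proposes (the `prologue_takes_subclass` check is hypothetical and not part of this PR; `prologue_trc`, `pythonex`, `get_executor`, and `transform_for_execution` are the names from the hunk above):

```python
# Hypothetical: include torchex only when the prologue actually produces
# SubclassTensorProxys, since it is needed solely for the
# tensor_subclass.__tensor_flatten__() calls emitted into the prologue.
from thunder.core.proxies import SubclassTensorProxy

prologue_takes_subclass = any(
    isinstance(p, SubclassTensorProxy)
    for bsym in prologue_trc.bound_symbols
    for p in bsym.flat_proxy_outs
)
executors_for_prologue = (pythonex, get_executor("torch")) if prologue_takes_subclass else (pythonex,)

prologue_traces += transform_for_execution(
    prologue_trc,
    executors_list=executors_for_prologue,
    use_del_last_used=False,
)
```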

use_del_last_used=False,
)
prologue_trc = prologue_traces[-1]
@@ -497,12 +498,14 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

_unroll_tensor_subclass_applied: bool = False
backward_trc = None
if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
requires_grad = any(isinstance(arg, tensor_cls) and arg.requires_grad for arg in computation_trc.args)

if requires_grad:
_unroll_tensor_subclass_applied = True
# Currently split_forward_backward also includes
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
@@ -513,6 +516,9 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward

if not _unroll_tensor_subclass_applied:
computation_trc, _ = unroll_tensor_subclasses(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
from thunder.executors.passes import _transform_for_operator_executor_execution
4 changes: 3 additions & 1 deletion thunder/clang/__init__.py
@@ -22,7 +22,9 @@
NumberLike,
NumberProxy,
Proxy,
SubclassTensorProxy,
TensorProxy,
proxy,
pytype,
pyval,
)
@@ -62,7 +64,7 @@ def __call__(self, fn: Callable) -> Callable:

# Checks a tensor's shape and metadata (for use with cache check)
@clangop()
def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None:
def check_tensor_shape_and_metadata(t: TensorProxy | SubclassTensorProxy, /) -> None:
return prims.check_tensor_shape_and_metadata(
t,
# replace Proxy entries with `-1`s as wild card, as we any value is
119 changes: 96 additions & 23 deletions thunder/core/jit_ext.py
@@ -40,6 +40,9 @@
ProxyInterface,
ProxyTag,
TensorProxy,
FutureTensorProxy,
SubclassTensorProxy,
make_proxy_name,
Variable,
is_proxy_name_available,
proxy,
@@ -737,6 +740,10 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
"""
from thunder.core.baseutils import check, sequencify
from thunder.core.trace_interpreter import interpret_trace
from thunder.core.transforms import dce
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.extend import TemporaryExecutor

custom_autograd_function_cls = unwrap(obj)
custom_forward = custom_autograd_function_cls.forward
@@ -752,25 +759,37 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return trace_of_fwd

# Forward.
# augmented forward trace.
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
trace_of_fwd._siginfo = SigInfo.from_name_and_args(
custom_autograd_function_cls.__name__,
unwrapped_custom_forward_args,
)
trace_of_fwd.args = unwrapped_custom_forward_args
unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
for a in filter(lambda a: isinstance(a, Proxy), unwrapped_custom_forward_args)
]
trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols

@wraps(trace_of_fwd.python_callable())
augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
tuple(sequencify(trace_of_fwd.output)),
ctx_proxy.saved_tensors,
)
trace_of_augmented_fwd = TraceCtx()
trace_of_augmented_fwd.bound_symbols.extend((unpack_bsyms + trace_of_fwd.bound_symbols)[:-1])
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(
custom_autograd_function_cls.__name__, unwrapped_custom_forward_args
)
trace_of_augmented_fwd.args = unwrapped_custom_forward_args
trace_of_augmented_fwd = dce(trace_of_augmented_fwd)
_, spec_of_fwd_output = tree_flatten(trace_of_fwd.output)

@wraps(trace_of_augmented_fwd.python_callable())
def core_of_forward(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)
output, _ = interpret_trace(trace_of_augmented_fwd, *args, **kwargs)
flat_output, _ = tree_flatten(output)
return tree_unflatten(flat_output, spec_of_fwd_output)

custom_fwd_sym = get_jit_ctx().ad_hoc_executor.register_operator(
trace_of_fwd._siginfo.name,
temporary_executor: TemporaryExecutor = get_jit_ctx().ad_hoc_executor
custom_fwd_sym = temporary_executor.register_operator(
custom_autograd_function_cls.__name__,
like=core_of_forward,
)
unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args)
@@ -779,17 +798,6 @@ def core_of_forward(*args, **kwargs):
provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, fwd_output_provenance]),
)

augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
tuple(sequencify(trace_of_fwd.output)),
ctx_proxy.saved_tensors,
)
trace_of_augmented_fwd = TraceCtx()
trace_of_augmented_fwd.bound_symbols.extend(trace_of_fwd.bound_symbols[:-1])
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args)
trace_of_augmented_fwd.args = unwrapped_custom_forward_args

# Backward definition
custom_backward = custom_autograd_function_cls.backward
grads = tree_map(
@@ -818,6 +826,7 @@ def core_of_forward(*args, **kwargs):
ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads,
)
bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)
bwd_trace_impl = dce(bwd_trace_impl)

@wraps(bwd_trace_impl.python_callable())
def bwd_impl_callable(*args, **kwargs):
@@ -843,6 +852,24 @@ def grad_transform(*args, **kwargs):
execution_transform=core_of_forward,
grad_transform=grad_transform,
)

added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
import_ctx.update(cur_import_ctx)
call_ctx.update(cur_call_ctx)
object_ctx.update(cur_object_ctx)

if import_ctx:
added_bsym._import_ctx.update(import_ctx)
if call_ctx:
if added_bsym._call_ctx is not None:
added_bsym._call_ctx.update(call_ctx)
else:
added_bsym._call_ctx = call_ctx
if object_ctx:
added_bsym._object_ctx.update(object_ctx)
return forward_result


@@ -1057,6 +1084,42 @@ def add_input_output_proxy_name(p):
return res


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(
cls: torch._C._TensorMeta,
size: Sequence[int],
strides: Sequence[int] | None = None,
storage_offset: int | None = None,
memory_format: torch.memory_format | None = None,
dtype: torch.dtype | None = None,
layout: torch.layout | None = torch.strided,
device: torch.device | None = None,
pin_memory: bool = False,
requires_grad: bool = False,
dispatch_sizes_strides_policy: str | None = None,
dispatch_device: bool = False,
dispatch_layout: bool = False,
_extra_dispatch_keys: torch.DispatchKeySet | None = None,
storage_size: int | None = None,
):
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)

subclass = SubclassTensorProxy(
None,
shape=usize,
device=udevice,
dtype=udtype,
requires_grad=urequires_grad,
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
subclass_type=ucls,
)
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
@@ -1793,9 +1856,12 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:

with tracectx(prologue_trace):
for prim, *args in ctx._constraints:
subclass_tensor: SubclassTensorProxy | None = None
for a in args:
if isinstance(a, Proxy):
unpack(a)
if isinstance(a, SubclassTensorProxy):
subclass_tensor = a
# unpacking Proxy in TensorProxy.shape which is used in `check_tensor_shape_and_metadata`
if prim == clang.check_tensor_shape_and_metadata:
for s in a.shape:
@@ -1804,6 +1870,13 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:

prim(*args)

if isinstance(subclass_tensor, SubclassTensorProxy):
for t in prims.flatten_tensor_subclass(subclass_tensor):
for s in t.shape:
if isinstance(s, Proxy):
unpack(s)
prim(t)

cache_info = thunder._get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
if cache_info:
157 changes: 155 additions & 2 deletions thunder/core/prims.py
@@ -1,18 +1,23 @@
from __future__ import annotations
from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
import operator
import builtins
import math
from types import NoneType
from typing import Union, Type, Any, List, Dict, Tuple, Optional
from typing import Union, Type, Any, List, Dict, Tuple, Optional, TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Callable, Hashable, Sequence

import torch

from thunder.core.langctxs import LanguageContext, register_langctx, Languages, langctx

if TYPE_CHECKING:
from collections.abc import Iterable
from thunder.core.codeutils import ContextObject

#
# Creates and registers the torch language context
#
@@ -78,6 +83,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
TupleProxy,
AnyProxy,
IntegerProxy,
SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -276,6 +282,10 @@ class PrimIDs(Enum):
COPY_ = auto()
#
SINK = auto()
# Tensor Subclasses methods
TENSOR_SUBCLASS_CTOR = auto()
FLATTEN_TENSOR_SUBCLASS = auto()
UNFLATTEN_TENSOR_SUBCLASS = auto()


class OpTags(Enum):
@@ -3676,7 +3686,11 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy:
if isinstance(a, SubclassTensorProxy):
shallow = SubclassTensorProxy(like=a)
shallow.copy_attributes_from(a)
return shallow
return TensorProxy(like=a)


@@ -4187,3 +4201,142 @@ def sink_meta(*args, **kwargs):

# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
s = SubclassTensorProxy(
name,
subclass_type=cls,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
)
return s


def get_nested_types(collection):
collection = utils.sequencify(collection)
types_set = {type(t) for t in collection}

def check_types(coll):
for item in coll:
types_set.add(type(item))
# Check if the item is a nested collection
if baseutils.is_collection(item):
# If it's a dictionary, check its values
if isinstance(item, dict):
check_types(item.values())
# Recursively check nested collections
else:
check_types(item)

check_types(collection)
return tuple(types_set)


def filter_types_for_tensor_wrapper_subclass(types: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(
filter(
lambda t: (
t.__module__ != "builtins"
and t != Number
# note(crcrpar): maybe `thunder.core`?
and not t.__module__.startswith("thunder.")
and not t.__module__.startswith("torch.")
),
types,
)
)


def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
non_tensors = bsym.args[-1]

filtered_types: tuple[Any, ...] = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


tensor_subclass_ctor = make_prim(
PrimIDs.TENSOR_SUBCLASS_CTOR,
"tensor_subclass_ctor",
meta=tensor_subclass_ctor_meta,
_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
)


# NOTE(crcrpar): The behavior differs from PyTorch's `subclass_tensor.__tensor_flatten__()`,
# which returns a list of tensor attr names and a dict of const metadata. In Thunder traces,
# const values can be omitted, and actual tensor proxies are more useful
# than tensor attr names.
def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]:
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = make_prim(
PrimIDs.FLATTEN_TENSOR_SUBCLASS,
"flatten_tensor_subclass",
meta=flatten_tensor_subclass_meta,
)
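
For contrast with the PyTorch protocol mentioned in the note above, a small sketch using the `ScaleTensorSubclass` test class defined later in this PR (illustrative only; the proxy-level prim is what traces actually record):

```python
import torch

# ScaleTensorSubclass is the wrapper subclass from thunder/tests/test_tensor_subclass.py (below).
s = ScaleTensorSubclass(torch.randn(2, 2), torch.tensor(1.0))

# PyTorch convention: attribute names plus const metadata ...
attr_names, metadata = s.__tensor_flatten__()          # (["_x", "_scale"], {})
inner_tensors = tuple(getattr(s, name) for name in attr_names)

# ... whereas the Thunder prim above returns the inner tensor proxies directly,
# mirroring what flatten_tensor_subclass_meta computes for a SubclassTensorProxy.
```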


def bind_postprocess_of_unflatten_tensor_subclass(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
inner_tensors = bsym.args[1]
metadata = bsym.args[2]

filtered_types: tuple[Any, ...] = (cls,)
if metadata:
types = get_nested_types(list(metadata.values()))
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


def unflatten_tensor_subclass_meta(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> SubclassTensorProxy:
first_tensor: TensorProxy = list(inner_tensors.values())[0]
a = SubclassTensorProxy(
shape=first_tensor.shape,
device=first_tensor.device,
dtype=first_tensor.dtype,
requires_grad=first_tensor.requires_grad,
tensors=list(inner_tensors.values()),
non_tensors=list(metadata.values()),
subclass_type=tensor_subclass_type,
)
for name, value in inner_tensors.items():
setattr(a, name, value)
for name, value in metadata.items():
setattr(a, name, value)
return a


def unflatten_tensor_subclass_python_impl(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> torch.Tensor:
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


unflatten_tensor_subclass = make_prim(
PrimIDs.UNFLATTEN_TENSOR_SUBCLASS,
"unflatten_tensor_subclass",
meta=unflatten_tensor_subclass_meta,
_bind_postprocess=bind_postprocess_of_unflatten_tensor_subclass,
)
336 changes: 330 additions & 6 deletions thunder/core/proxies.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions thunder/core/pytree.py
@@ -1,3 +1,4 @@
from enum import Enum
from functools import partial
from types import FunctionType
import dataclasses
@@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
and not issubclass(type(args), Enum)
):
raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
8 changes: 5 additions & 3 deletions thunder/core/trace.py
@@ -178,9 +178,11 @@ def set_provenance(self, provenance: TraceProvenance) -> None:
# Methods related to name construction
#

def add_name(self, name: str) -> None:
def add_name(self, name: str, *, prefix: str | None = None) -> None:
from thunder.core.proxies import PREFIXES_ALLOW_NAME_DUPLICATES

baseutils.check(
name not in self.names,
name not in self.names or (prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES),
Review comment (Collaborator Author) on lines +181 to +185:
If a program calls a custom torch.autograd.Function that builds a tensor subclass instance from torch.Tensors, the torch.autograd.Function lookaside is invoked. The lookaside creates a trace of that function, and that trace contains a BoundSymbol for the tensor-subclass-ctor prim, which calls SubclassTensorProxy.__init__, which calls Proxy.__init__, which calls TraceCtx.add_name.
The trace's bound symbols are then passed to OpExProcessor and their sym.meta is re-evaluated, so the trace tries to add the same name for the subclass tensor proxy at least twice. This special casing allows that duplication.
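
A minimal sketch of the tolerance this check introduces (the concrete prefix value is an assumption; `PREFIXES_ALLOW_NAME_DUPLICATES` lives in the proxies.py diff, which is not rendered on this page):

```python
from thunder.core.trace import TraceCtx

trc = TraceCtx()
trc.add_name("subclass_t0")                    # first registration succeeds
# Previously the call below raised "Trying to add the name ... already used".
# With this change it is tolerated when `prefix` is in PREFIXES_ALLOW_NAME_DUPLICATES
# (assumed here to cover the prefix used for SubclassTensorProxy names).
trc.add_name("subclass_t0", prefix="subclass")
```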

lambda: f"Trying to add the name {name} to a trace, but that name is already used",
)
self.names.add(name)
@@ -221,7 +223,7 @@ def _make_name(self, *, prefix: str | None = None, is_object_name: bool = False,
# just records the given name
def make_name(self, name: str | None = None, *, prefix: str | None = None) -> str:
if name is not None:
self.add_name(name)
self.add_name(name, prefix=prefix)
return name

return self._make_name(prefix=prefix)
7 changes: 5 additions & 2 deletions thunder/core/trace_interpreter.py
@@ -131,8 +131,11 @@ def add_to_swap_map(old, new):
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
swap_map[variableify(new.primal)] = old
new.primal = old
# note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`
# seems to cause `new.primal` == `old`, leading to a cycle in swapping.
if (key := variableify(new.primal)) != variableify(old):
Review comment (Collaborator Author) on lines +134 to +136:
I need to revisit this comment

swap_map[variableify(new.primal)] = old
new.primal = old
else:
assert isinstance(new, ProxyInterface), (old, new)
swap_map[variableify(new)] = old
2 changes: 1 addition & 1 deletion thunder/core/transform_common.py
@@ -168,7 +168,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
# may mark some of the operation's outputs as unused
some_unused = False
for out in bsym.flat_proxy_outs:
if variableify(out) in needed_proxies and producer_map[out] == bsym:
if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
needed = True
else:
some_unused = True
10 changes: 8 additions & 2 deletions thunder/executors/nvfuserex_impl.py
@@ -1715,13 +1715,19 @@ def trunc(a: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict)
register_supported(PrimIDs.TRUNC, trunc, _elementwise_unary_check)


def clone(a: TensorProxy, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
def clone(a: TensorProxy, memory_format=torch.preserve_format, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
nva = getnv(a, fd, lc_to_nv_map)

return fd.ops.set(nva)


register_supported(PrimIDs.CLONE, clone, _elementwise_unary_check)
def _clone_check(a: TensorProxy, *, memory_format: torch.memory_format = torch.preserve_format) -> bool:
if memory_format not in (torch.preserve_format, torch.contiguous_format):
return False
return _elementwise_unary_check(a)


register_supported(PrimIDs.CLONE, clone, _clone_check)

#
# Elementwise binary operations
25 changes: 19 additions & 6 deletions thunder/executors/passes.py
@@ -61,13 +61,26 @@ def process_bsym(self, bsym):
self.add_bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs)
return
elif isinstance(ex, OperatorExecutor):
from thunder.extend import TemporaryExecutor

# NOTE execution_transform is None and the executor is an operator executor
# Calls the operator executor's operation
# TODO Instead of directly acquiring the symbol from the implmap, we probably
# want to hide this behind a function
op = ex.implmap[bsym.sym.id].symbol
self.add_bsyms_from_function(op, *bsym.args, **bsym.kwargs)
return
if ex.name == "torch" or isinstance(ex, TemporaryExecutor):
# For TorchExecutor, we can bypass the function call in add_bsyms_from_function
# and directly create the bound symbol
op = ex.implmap[bsym.sym.id].symbol

# Create a bound symbol directly without executing the function
new_bsym = op.bind(*bsym.args, **bsym.kwargs, output=bsym.output)
self.add_processed_bsyms([new_bsym])

# Set the result without actually executing the function
self.set_result(new_bsym.output)
return
else:
# For other OperatorExecutors, use the original approach
op = ex.implmap[bsym.sym.id].symbol
self.add_bsyms_from_function(op, *bsym.args, **bsym.kwargs)
return
elif isinstance(ex, FusionExecutor):
# NOTE execution_transform is None and the executor is a fusion executor
# Preserves the symbol as is (it will be handled in the fusion pass)
9 changes: 9 additions & 0 deletions thunder/executors/torch_autograd.py
@@ -223,6 +223,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.distributed.transforms import FSDPCommBucketing
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
from thunder.executors.passes import del_last_used, transform_for_execution
from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses

utils.check(compile_data is not None, lambda: "`compile_data` is required")
# NOTE: This function is rather slow, so it's intended to be used
@@ -249,6 +250,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
fw_traces = [fw_trace]
bw_traces = [bw_trace]

fw_trace, saved_proxy_for_bwd_to_strides = unroll_tensor_subclasses(
fw_trace, is_bwd_trace=False, proxy_to_strides=None
)
fw_traces.append(fw_trace)

from thunder.distributed import FSDPType

# only enable rematerialize_params_in_backward when using FSDP ZeRO3
@@ -353,6 +359,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
if getattr(compile_data.fn, "use_fsdp", False):
bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)

bw_trace, _ = unroll_tensor_subclasses(bw_trace, is_bwd_trace=True, proxy_to_strides=saved_proxy_for_bwd_to_strides)
bw_traces.append(bw_trace)

# Now we can run the optimization passes on the backward trace
# TODO Restore request for no rematerialization
bw_extrace = transform_for_execution(
6 changes: 6 additions & 0 deletions thunder/executors/torch_compile.py
@@ -56,6 +56,12 @@ def _to_torch(*args, **kwargs) -> Any:
if torch_op is None:
raise RuntimeError(f"op not found for {bsym.sym.name}")

# NOTE(crcrpar): Currently `ltorch.t` is mapped to `torchex.transpose`
# thus `args` needs to be updated to have dim0 and dim1
if bsym.sym.id == "torch.t":
utils.check(len(args) == 1, lambda: f"{bsym.sym.id} takes only one argument but {args=}")
args = args + (0, 1)

return torch_op(*args, **kwargs)

return _to_torch
195 changes: 192 additions & 3 deletions thunder/executors/torchex.py
@@ -17,8 +17,10 @@
import thunder.core.devices as devices
from thunder.core.devices import to_torch_device, to_device
import thunder.core.prims as prims
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, pytype
from thunder.core.symbol import Symbol
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype, Proxy
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol, BoundSymbol
from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace
from thunder.distributed.prims import DistributedReduceOps
import thunder.distributed.prims as dist_prims
import thunder.core.utils as utils
@@ -33,7 +35,10 @@
)

if TYPE_CHECKING:
from typing import Any
from collections.abc import Iterable
from thunder.common import CompileData
from thunder.core.codeutils import ContextObject, Printable

ex = OperatorExecutor("torch", version=torch.__version__)
register_executor(ex)
@@ -451,7 +456,7 @@ def _empty_prims_transform(


def _clone_prims_transform(a: TensorLike, **kwargs) -> TensorLike:
return clone(a)
return clone(a, memory_format=kwargs.get("memory_format", torch.preserve_format))


def _tensor_from_sequence_prims_transform(
@@ -1414,13 +1419,44 @@ def _copy_with_setitem_impl(a, key, value):
#

matmul = _register_torch_operation("matmul")
_scaled_mm = _register_torch_operation("_scaled_mm")
outer = _register_torch_operation("outer")

_register_implementation(prims.matmul, matmul, checker=_always_executable)

_register_implementation(ltorch.matmul, matmul, checker=_always_executable)
_register_implementation(ltorch.outer, outer, checker=_always_executable)


def _scaled_mm_impl(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypeLike | None = None,
use_fast_accum: bool = False,
):

def is_row_major(mat: TensorLike) -> bool:
return mat.stride()[1] == 1 and mat.stride()[0] > 1

def is_column_major(mat: TensorLike) -> bool:
return mat.stride()[0] == 1 and mat.stride()[1] > 1

utils.check(is_row_major(a), lambda: f"`a` expected to be row-major but its stride is {a.stride()=}")
utils.check(is_column_major(b), lambda: f"`b` expected to be column-major but its stride is {b.stride()=}")
result_dtype: torch.dtype = to_torch_dtype(a.dtype if out_dtype is None else out_dtype)

return torch._scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)


_scaled_mm_impl = ex.register_operator("_scaled_mm_impl", like=ltorch._scaled_mm, fn=_scaled_mm_impl)
_register_implementation(ltorch._scaled_mm, _scaled_mm_impl, checker=_always_executable)
_register_implementation(ltorch.core_aten_scaled_mm, _scaled_mm_impl, checker=_always_executable)


#
# Normalization operations
#
@@ -2277,3 +2313,156 @@ def _shape_impl(t):

shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x)
_register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable)


def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(to_torch_device(a))
else:
new_non_tensors.append(a)
return cls(*tensors, *new_non_tensors)


def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
from thunder.core.prims import get_nested_types, filter_types_for_tensor_wrapper_subclass

cls, _name, _shape, _device, _dtype, _requires_grad, _tensors, non_tensors = bsym.args
filtered_types = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


def printer_of_tensor_subclass_ctor(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain
from thunder.core import baseutils
from thunder.core import codeutils

baseutils.check(not kwarg_printables, lambda: f"No kwargs are supported but {kwarg_printables = }")

# NOTE(crcrpar): It's not a context, but at the moment a Tensor subclass is treated as a `ContextObject`.
wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
if isinstance(wrapped_cls, torch._C._TensorMeta):
cls = wrapped_cls
else:
cls: torch._C._TensorMeta = wrapped_cls.obj
tensors, non_tensors = arg_printables[-2:]
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(dtypes.to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(devices.to_torch_device(a))
else:
new_non_tensors.append(a)

arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
kwarg_str = ""

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str: str
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"
else:
comment_str = ""

cls_with_module = f"{cls.__name__}"
s = f"{result_str}{cls_with_module}({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

return s


tensor_subclass_ctor = ex.register_operator(
"tensor_subclass_ctor",
meta=prims.tensor_subclass_ctor,
fn=_tensor_subclass_ctor,
bind_postprocess=_bind_postprocess_of_tensor_subclass_ctor,
python_printer=printer_of_tensor_subclass_ctor,
)
_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable)


def flatten_tensor_subclass_impl(t):
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = ex.register_operator(
"flatten_tensor_subclass",
meta=prims.flatten_tensor_subclass.meta,
fn=flatten_tensor_subclass_impl,
)
_register_implementation(
prims.flatten_tensor_subclass,
flatten_tensor_subclass,
checker=_always_executable,
)


def unflatten_tensor_subclass_impl(
tensor_subclass_type: torch._C._TensorMeta,
inner_tensors: dict[str, TensorLike],
metadata: dict,
):
for key in metadata:
v = metadata[key]
if isinstance(v, dtypes.dtype):
metadata[key] = to_torch_dtype(v)
elif isinstance(v, devices.Device):
metadata[key] = to_torch_device(v)
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


def bind_postprocess_of_unflatten_tensor_subclass(bsym: BoundSymbol) -> None:
from thunder.core.prims import filter_types_for_tensor_wrapper_subclass, get_nested_types

cls = bsym.args[0]
_inner_tensors = bsym.args[1]
metadata = bsym.args[2]

filtered_types: tuple[Any, ...] = (cls,)
if metadata:
types = get_nested_types(list(metadata.values()))
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


unflatten_tensor_subclass = ex.register_operator(
"unflatten_tensor_subclass",
meta=prims.unflatten_tensor_subclass.meta,
fn=unflatten_tensor_subclass_impl,
bind_postprocess=bind_postprocess_of_unflatten_tensor_subclass,
)
_register_implementation(
prims.unflatten_tensor_subclass,
unflatten_tensor_subclass,
checker=_always_executable,
)
330 changes: 330 additions & 0 deletions thunder/tests/test_tensor_subclass.py
@@ -0,0 +1,330 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import pytest
import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

import thunder
from thunder.dynamo.compiler import ThunderCompiler
from thunder.tests.framework import (
DynamoThunderExecutor,
TorchExecutor,
instantiate,
nvFuserExecutor,
)
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
from typing import Any


TORCHAO_AVAILABLE = package_available("torchao")


@torch._dynamo.allow_in_graph
class EncapsulateXandScale(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, scale: torch.Tensor):
return ScaleTensorSubclass(x, scale)

@staticmethod
def backward(ctx, grad):
return grad, None


def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass:
return EncapsulateXandScale.apply(x, scale)


@torch._dynamo.allow_in_graph
class ToScaleTensorSubclass(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
return ScaleTensorSubclass.from_tensor(x)

@staticmethod
def backward(ctx, grad):
return grad


def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass:
return ToScaleTensorSubclass.apply(x)


class ScaleTensorSubclass(torch.Tensor):
_x: torch.Tensor
_scale: torch.Tensor
__slots__ = ["_x", "_scale"]

def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
assert scale.numel() == 1, f"Invalid `scale`: {scale}"
dtype = x.dtype
device = x.device
self = torch.Tensor._make_wrapper_subclass(
cls,
x.size(),
dtype=dtype,
device=device,
# strides=x.stride(),
# storage_offset=x.storage_offset(),
# layout=x.layout,
requires_grad=x.requires_grad,
)
self._x = x
self._scale = scale

return self

# ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22
__torch_function__ = torch._C._disabled_torch_function_impl

def __repr__(self):
return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})"

def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]:
return ["_x", "_scale"], {}

@staticmethod
def __tensor_unflatten__(
inner_tensors: dict[str, torch.Tensor],
metadata: dict[str, Any],
outer_size,
outer_stride,
) -> ScaleTensorSubclass:
return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"])

@staticmethod
def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass:
scale = x.abs().max()
return ScaleTensorSubclass(x, scale)

@classmethod
def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None):

def allowed_subclass(typ):
return (
issubclass(cls, typ)
or issubclass(torch._subclasses.FakeTensor, typ)
or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ)
)

def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any):
if isinstance(t, ScaleTensorSubclass):
if t.is_floating_point():
return t._x * t._scale
else:
return t._x
return t

if not all(allowed_subclass(t) for t in types):
raise NotImplementedError(f"Unsupported types are included: {types}")

scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass))
unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs))
out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs)
return out


@instantiate(
dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_ctor_wrapper(executor, device, _):

def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
y = ScaleTensorSubclass(x, scale)
return y

jitted = executor.make_callable(f)

dtype = torch.float32
shape = (2, 2)
x = make_tensor(shape, device=device, dtype=dtype)
scale = make_tensor((), device=device, dtype=dtype)

expected = f(x, scale)
actual = jitted(x, scale)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

def f(x: torch.Tensor, scale: torch.Tensor):
y = ScaleTensorSubclass(x, scale)
z = ScaleTensorSubclass(y._x, y._scale)
return z

jitted = executor.make_callable(f)

expected = f(x, scale)
actual = jitted(x, scale)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))


@instantiate(
dtypes=(thunder.core.dtypes.float32,),
)
def test_func_calling_converter(executor, device, _):

def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
y = encapsulate_x_and_scale(x, scale)
return y

jitted = executor.make_callable(f)

dtype = torch.float32
shape = (2, 2)

x = make_tensor(shape, device=device, dtype=dtype)
scale = make_tensor((), device=device, dtype=dtype)

expected = f(x, scale)
actual = jitted(x, scale)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

def g(x: torch.Tensor) -> ScaleTensorSubclass:
y = to_scale_tensor_subclass(x)
return y

jitted = thunder.jit(g)
x = make_tensor(shape, device=device, dtype=dtype)

expected = g(x)
actual = jitted(x)
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))


@instantiate(
dtypes=(thunder.core.dtypes.float32,),
decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("fwd_only", "with_bwd")),),
)
def test_func_of_subclass_simple_math(executor, device, _, requires_grad):

def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor:
out = x + y
return out

jitted = executor.make_callable(f)

dtype = torch.float32
shape = (2, 2)
x = ScaleTensorSubclass(
make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
make_tensor((), device=device, dtype=dtype),
)
y = ScaleTensorSubclass(
make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
make_tensor((), device=device, dtype=dtype),
)

expected = f(x, y)
actual = jitted(x, y)
assert type(expected) is type(actual)
torch.testing.assert_close(expected, actual)
if requires_grad:
actual.mean().backward()

def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
y = EncapsulateXandScale.apply(data, scale)
out = x + y
return out

jitted = executor.make_callable(g)

x = ScaleTensorSubclass(
make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
make_tensor((), device=device, dtype=dtype),
)
data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
scale = make_tensor((), device=device, dtype=dtype)

expected = g(x, data, scale)
actual = jitted(x, data, scale)
assert type(expected) is type(actual)
torch.testing.assert_close(expected, actual)
if requires_grad:
actual.mean().backward()


@instantiate(
dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16),
devicetypes=(thunder.core.devices.DeviceType.CUDA,),
executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor),
decorators=(
pytest.mark.skipif(
not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
reason="Requires capability >= 8.9 and torchao",
),
pytest.mark.parametrize("bias", (True, False)),
),
)
def test_torchao_float8_linear(executor, device, dtype, bias):
from torchao.float8 import convert_to_float8_training

batch_size, in_features, out_features = 16, 32, 64
device = torch.device("cuda")
torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype)

model = nn.Sequential(
nn.Linear(in_features, out_features, bias=bias),
nn.GELU(approximate="none"),
nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=torch_dtype)
fp8_model = convert_to_float8_training(model)
x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)

expected: torch.Tensor
jitted: nn.Module
backend: ThunderCompiler | None = None

if is_thunderfx := executor == DynamoThunderExecutor:
torch._dynamo.reset()
expected = torch.compile(fp8_model)(x)
backend = ThunderCompiler()
jitted = torch.compile(fp8_model, backend=backend)
else:
expected = fp8_model(x)
jitted = executor.make_callable(fp8_model)

if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor:
# ref: https://github.com/NVIDIA/Fuser/issues/4052
with pytest.raises(RuntimeError, match="INTERNAL ASSERT FAILED"):
jitted(x)
return
actual = jitted(x)
if bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor:
with pytest.raises(AssertionError, match="Tensor-likes are not close"):
torch.testing.assert_close(actual, expected)
return

if bias and executor == nvFuserExecutor and dtype == thunder.core.dtypes.bfloat16:
# ref: https://github.com/NVIDIA/Fuser/issues/4052
pass
else:
with torch.no_grad():
grad = torch.ones_like(actual)
if executor == nvFuserExecutor and (
not (not bias and (dtype == thunder.core.dtypes.float32 or dtype == thunder.core.dtypes.bfloat16))
):
if bias and dtype == thunder.core.dtypes.float32:
with pytest.raises(RuntimeError, match="Expected mat1 to be Float8 matrix got Float"):
actual.backward(grad)
else:
with pytest.raises(RuntimeError, match="`b` expected to be column-major but"):
actual.backward(grad)
else:
actual.backward(grad)

if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
):
pytest.xfail("numerical error")
torch.testing.assert_close(actual, expected)

# TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
# Currently no subgraphs go to thunder.jit.
if is_thunderfx:
for subgraph in backend.subgraph_infos:
if not bias and dtype == thunder.core.dtypes.bfloat16:
assert not subgraph.thunder_compiled_fns
else:
assert subgraph.thunder_compiled_fns
1,082 changes: 1,035 additions & 47 deletions thunder/torch/__init__.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions thunder/transforms/__init__.py
@@ -3,6 +3,7 @@
from .qlora import LORATransform
from .prune_prologue_checks import PrunePrologueChecks
from .extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
from .tensor_wrapper_subclass import unroll_tensor_subclasses


__all__ = [
@@ -11,4 +12,5 @@
"MaterializationTransform",
"PrunePrologueChecks",
"ExtractionOnlyPrologueTransform",
"unroll_tensor_subclasses",
]
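
A minimal usage sketch of the exported transform (illustrative only: inside thunder it is invoked automatically from `apply_transforms_and_build_cache_entry` and `split_forward_backward` as shown earlier in this diff; the standalone call below and the `ScaleTensorSubclass` inputs reuse the test class from `thunder/tests/test_tensor_subclass.py` above):

```python
import torch
import thunder
from thunder.transforms import unroll_tensor_subclasses

def f(x, y):
    return x + y

jf = thunder.jit(f)
# ScaleTensorSubclass: the wrapper subclass from thunder/tests/test_tensor_subclass.py above.
x = ScaleTensorSubclass(torch.randn(2, 2), torch.randn(()))
y = ScaleTensorSubclass(torch.randn(2, 2), torch.randn(()))
out = jf(x, y)

# The transform maps trace -> (trace, proxy_to_strides); applying it to the last
# computation trace should be close to a no-op here, since thunder.jit has already
# unrolled the subclass ops per this PR.
trc = thunder.last_traces(jf)[-1]
unrolled_trc, _ = unroll_tensor_subclasses(trc)
```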
1,130 changes: 1,130 additions & 0 deletions thunder/transforms/tensor_wrapper_subclass.py

Large diffs are not rendered by default.