Trace Transform for Tensor Wrapper Subclasses #1883
base: main
```diff
@@ -8,3 +8,4 @@ thunder.transforms

     MaterializationTransform
     ConstantFolding
+    unroll_tensor_subclasses
```
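As a hedged sketch of what the new entry gives users: only the public name `unroll_tensor_subclasses` is grounded in the documentation hunk above; the commented-out application is an assumption inferred from the PR title, not from this diff.

```python
# Hedged sketch -- only the imported name comes from the docs entry above.
from thunder.transforms import unroll_tensor_subclasses

# Assumed usage (not shown in this hunk): applied as a trace-to-trace transform,
# e.g. on a computation trace `computation_trc` produced by thunder.jit:
# unrolled_trc = unroll_tensor_subclasses(computation_trc)
```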
(Another changed file in this PR: large diff not rendered by default.)
```diff
@@ -178,9 +178,11 @@ def set_provenance(self, provenance: TraceProvenance) -> None:
     # Methods related to name construction
     #

-    def add_name(self, name: str) -> None:
+    def add_name(self, name: str, *, prefix: str | None = None) -> None:
+        from thunder.core.proxies import PREFIXES_ALLOW_NAME_DUPLICATES
+
         baseutils.check(
-            name not in self.names,
+            name not in self.names or (prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES),
             lambda: f"Trying to add the name {name} to a trace, but that name is already used",
         )
         self.names.add(name)
```

Review comment on lines +181 to +185: If a program calls a custom …

```diff
@@ -221,7 +223,7 @@ def _make_name(self, *, prefix: str | None = None, is_object_name: bool = False,
     # just records the given name
     def make_name(self, name: str | None = None, *, prefix: str | None = None) -> str:
         if name is not None:
-            self.add_name(name)
+            self.add_name(name, prefix=prefix)
             return name

         return self._make_name(prefix=prefix)
```
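To make the relaxed check concrete, here is a small self-contained sketch of the same policy. `PREFIXES_ALLOW_NAME_DUPLICATES` mirrors the constant referenced in the diff; the registry class, the `"tos"` prefix, and the example names are illustrative assumptions, not Thunder code.

```python
from __future__ import annotations

# Assumed contents; the real constant lives in thunder.core.proxies.
PREFIXES_ALLOW_NAME_DUPLICATES = {"tos"}


class NameRegistry:
    """Toy stand-in for a trace's name bookkeeping."""

    def __init__(self) -> None:
        self.names: set[str] = set()

    def add_name(self, name: str, *, prefix: str | None = None) -> None:
        # Duplicates are tolerated only for whitelisted prefixes.
        duplicates_ok = prefix is not None and prefix in PREFIXES_ALLOW_NAME_DUPLICATES
        assert name not in self.names or duplicates_ok, f"name {name!r} already used"
        self.names.add(name)


reg = NameRegistry()
reg.add_name("tos0", prefix="tos")
reg.add_name("tos0", prefix="tos")  # allowed: prefix is whitelisted
# reg.add_name("t0"); reg.add_name("t0")  # would fail: ordinary names stay unique
```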
```diff
@@ -131,8 +131,11 @@ def add_to_swap_map(old, new):
             old = old.replace(shape=new._shape)

         if isinstance(new, VJPDual):
-            swap_map[variableify(new.primal)] = old
-            new.primal = old
+            # note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`
+            # seems to cause `new.primal` == `old`, leading to a cycle in swapping.
+            if (key := variableify(new.primal)) != variableify(old):
+                swap_map[variableify(new.primal)] = old
+                new.primal = old
         else:
             assert isinstance(new, ProxyInterface), (old, new)
             swap_map[variableify(new)] = old
```

Review comment on lines +134 to +136: I need to revisit this comment.
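A minimal sketch of the cycle the new guard avoids, using a plain dict in place of Thunder's swap map; `variableify` here is a stand-in for the real helper and the proxy names are hypothetical.

```python
# Stand-in for thunder.core.proxies.variableify, which wraps a proxy
# into a hashable Variable; strings are enough to show the shape of the bug.
def variableify(proxy: str) -> str:
    return proxy


swap_map: dict[str, str] = {}
old, new_primal = "t0", "t0"  # __tensor_flatten__ can make these refer to the same proxy

if variableify(new_primal) != variableify(old):
    # Normal case: record that new_primal should be replaced by old.
    swap_map[variableify(new_primal)] = old
# else: skip -- mapping t0 -> t0 would make later swapping chase its own tail.
```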
New test file added by this PR (hunk `@@ -0,0 +1,330 @@`):

```python
from __future__ import annotations
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import pytest
import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

import thunder
from thunder.dynamo.compiler import ThunderCompiler
from thunder.tests.framework import (
    DynamoThunderExecutor,
    TorchExecutor,
    instantiate,
    nvFuserExecutor,
)
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
    from typing import Any


TORCHAO_AVAILABLE = package_available("torchao")


@torch._dynamo.allow_in_graph
class EncapsulateXandScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: torch.Tensor):
        return ScaleTensorSubclass(x, scale)

    @staticmethod
    def backward(ctx, grad):
        return grad, None


def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass:
    return EncapsulateXandScale.apply(x, scale)


@torch._dynamo.allow_in_graph
class ToScaleTensorSubclass(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        return ScaleTensorSubclass.from_tensor(x)

    @staticmethod
    def backward(ctx, grad):
        return grad


def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass:
    return ToScaleTensorSubclass.apply(x)


class ScaleTensorSubclass(torch.Tensor):
    _x: torch.Tensor
    _scale: torch.Tensor
    __slots__ = ["_x", "_scale"]

    def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
        assert scale.numel() == 1, f"Invalid `scale`: {scale}"
        dtype = x.dtype
        device = x.device
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            x.size(),
            dtype=dtype,
            device=device,
            # strides=x.stride(),
            # storage_offset=x.storage_offset(),
            # layout=x.layout,
            requires_grad=x.requires_grad,
        )
        self._x = x
        self._scale = scale

        return self

    # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})"

    def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]:
        return ["_x", "_scale"], {}

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: dict[str, torch.Tensor],
        metadata: dict[str, Any],
        outer_size,
        outer_stride,
    ) -> ScaleTensorSubclass:
        return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"])

    @staticmethod
    def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass:
        scale = x.abs().max()
        return ScaleTensorSubclass(x, scale)

    @classmethod
    def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None):

        def allowed_subclass(typ):
            return (
                issubclass(cls, typ)
                or issubclass(torch._subclasses.FakeTensor, typ)
                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ)
            )

        def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any):
            if isinstance(t, ScaleTensorSubclass):
                if t.is_floating_point():
                    return t._x * t._scale
                else:
                    return t._x
            return t

        if not all(allowed_subclass(t) for t in types):
            return NotImplementedError(f"Unsupported types are included: {types}")

        scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass))
        unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs))
        out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs)
        return out


@instantiate(
    dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_ctor_wrapper(executor, device, _):

    def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
        y = ScaleTensorSubclass(x, scale)
        return y

    jitted = executor.make_callable(f)

    dtype = torch.float32
    shape = (2, 2)
    x = make_tensor(shape, device=device, dtype=dtype)
    scale = make_tensor((), device=device, dtype=dtype)

    expected = f(x, scale)
    actual = jitted(x, scale)
    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

    def f(x: torch.Tensor, scale: torch.Tensor):
        y = ScaleTensorSubclass(x, scale)
        z = ScaleTensorSubclass(y._x, y._scale)
        return z

    jitted = executor.make_callable(f)

    expected = f(x, scale)
    actual = jitted(x, scale)
    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))


@instantiate(
    dtypes=(thunder.core.dtypes.float32,),
)
def test_func_calling_converter(executor, device, _):

    def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
        y = encapsulate_x_and_scale(x, scale)
        return y

    jitted = executor.make_callable(f)

    dtype = torch.float32
    shape = (2, 2)

    x = make_tensor(shape, device=device, dtype=dtype)
    scale = make_tensor((), device=device, dtype=dtype)

    expected = f(x, scale)
    actual = jitted(x, scale)
    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))

    def g(x: torch.Tensor) -> ScaleTensorSubclass:
        y = to_scale_tensor_subclass(x)
        return y

    jitted = thunder.jit(g)
    x = make_tensor(shape, device=device, dtype=dtype)

    expected = g(x)
    actual = jitted(x)
    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))


@instantiate(
    dtypes=(thunder.core.dtypes.float32,),
    decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("fwd_only", "with_bwd")),),
)
def test_func_of_subclass_simple_math(executor, device, _, requires_grad):

    def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor:
        out = x + y
        return out

    jitted = executor.make_callable(f)

    dtype = torch.float32
    shape = (2, 2)
    x = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    y = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )

    expected = f(x, y)
    actual = jitted(x, y)
    assert type(expected) is type(actual)
    torch.testing.assert_close(expected, actual)
    if requires_grad:
        actual.mean().backward()

    def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        y = EncapsulateXandScale.apply(data, scale)
        out = x + y
        return out

    jitted = executor.make_callable(g)

    x = ScaleTensorSubclass(
        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
        make_tensor((), device=device, dtype=dtype),
    )
    data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
    scale = make_tensor((), device=device, dtype=dtype)

    expected = g(x, data, scale)
    actual = jitted(x, data, scale)
    assert type(expected) is type(actual)
    torch.testing.assert_close(expected, actual)
    if requires_grad:
        actual.mean().backward()


@instantiate(
    dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16),
    devicetypes=(thunder.core.devices.DeviceType.CUDA,),
    executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor),
    decorators=(
        pytest.mark.skipif(
            not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
            reason="Requires capability >= 8.9 and torchao",
        ),
        pytest.mark.parametrize("bias", (True, False)),
    ),
)
def test_torchao_float8_linear(executor, device, dtype, bias):
    from torchao.float8 import convert_to_float8_training

    batch_size, in_features, out_features = 16, 32, 64
    device = torch.device("cuda")
    torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype)

    model = nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.GELU(approximate="none"),
        nn.Linear(out_features, out_features, bias=bias),
    ).to(device=device, dtype=torch_dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)

    expected: torch.Tensor
    jitted: nn.Module
    backend: ThunderCompiler | None = None

    if is_thunderfx := executor == DynamoThunderExecutor:
        torch._dynamo.reset()
        expected = torch.compile(fp8_model)(x)
        backend = ThunderCompiler()
        jitted = torch.compile(fp8_model, backend=backend)
    else:
        expected = fp8_model(x)
        jitted = executor.make_callable(fp8_model)

    if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor:
        # ref: https://github.com/NVIDIA/Fuser/issues/4052
        with pytest.raises(RuntimeError, match="INTERNAL ASSERT FAILED"):
            jitted(x)
        return
    actual = jitted(x)
    if bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor:
        with pytest.raises(AssertionError, match="Tensor-likes are not close"):
            torch.testing.assert_close(actual, expected)
        return

    if bias and executor == nvFuserExecutor and dtype == thunder.core.dtypes.bfloat16:
        # ref: https://github.com/NVIDIA/Fuser/issues/4052
        pass
    else:
        with torch.no_grad():
            grad = torch.ones_like(actual)
        if executor == nvFuserExecutor and (
            not (not bias and (dtype == thunder.core.dtypes.float32 or dtype == thunder.core.dtypes.bfloat16))
        ):
            if bias and dtype == thunder.core.dtypes.float32:
                with pytest.raises(RuntimeError, match="Expected mat1 to be Float8 matrix got Float"):
                    actual.backward(grad)
            else:
                with pytest.raises(RuntimeError, match="`b` expected to be column-major but"):
                    actual.backward(grad)
        else:
            actual.backward(grad)

    if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
        not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
    ):
        pytest.xfail("numerical error")
    torch.testing.assert_close(actual, expected)

    # TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
    # Currently no subgraphs go to thunder.jit.
    if is_thunderfx:
        for subgraph in backend.subgraph_infos:
            if not bias and dtype == thunder.core.dtypes.bfloat16:
                assert not subgraph.thunder_compiled_fns
            else:
                assert subgraph.thunder_compiled_fns
```
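For readers unfamiliar with the wrapper-subclass protocol these tests exercise, the round trip below shows how `__tensor_flatten__` and `__tensor_unflatten__` cooperate. It assumes the `ScaleTensorSubclass` defined above is in scope and is illustrative rather than part of the diff.

```python
# Illustrative round trip through the hooks the unrolling transform relies on.
import torch

t = ScaleTensorSubclass(torch.randn(2, 2), torch.tensor(0.5))

attrs, metadata = t.__tensor_flatten__()            # (["_x", "_scale"], {})
inner = {name: getattr(t, name) for name in attrs}  # the plain inner torch.Tensors

rebuilt = ScaleTensorSubclass.__tensor_unflatten__(inner, metadata, t.size(), t.stride())
assert isinstance(rebuilt, ScaleTensorSubclass)
torch.testing.assert_close(rebuilt._x, t._x)
torch.testing.assert_close(rebuilt._scale, t._scale)
```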
(Two more changed files in this PR: large diffs not rendered by default.)
Review comment: To be strict and precise, I want to include torchex only when the prologue takes `SubclassTensorProxy`s, because torchex is only here for `tensor_subclass.__tensor_flatten__()` used inside the prologue.
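A hedged sketch of the gating that comment suggests; the import location and the trace attributes are assumptions about Thunder internals rather than code from this PR.

```python
# Hypothetical helper: only pull in the torch executor when the prologue trace
# actually receives SubclassTensorProxy arguments (attribute names assumed).
def prologue_takes_subclass_proxy(prologue_trace) -> bool:
    from thunder.core.proxies import SubclassTensorProxy  # assumed import location

    return any(isinstance(arg, SubclassTensorProxy) for arg in prologue_trace.args)


# Assumed usage at the call site that assembles executors for the prologue:
# if prologue_takes_subclass_proxy(prologue_trc):
#     executors = (*executors, torch_executor)  # torchex handle, name assumed
```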