Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nvfuserex] Decomposed torch._scaled_mm #1749

Draft
wants to merge 6 commits into
base: subclass_tensor-type-str
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2593,3 +2593,94 @@ def scaled_dot_product_flash_attention_grad(
execution_transform=scaled_dot_product_flash_attention,
grad_transform=scaled_dot_product_flash_attention_grad,
)


def _decomposed_scaled_mm_meta(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypes.dtype | None = None,
use_fast_accum: bool = False,
) -> TensorLike:
dtype = dtypes.to_dtype(out_dtype) if out_dtype is not None else a.dtype
return TensorProxy(like=a, shape=(a.shape[0], b.shape[1]), device=a.device, dtype=dtype)


def _decomposed_scaled_mm_impl(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypes.dtype | None = None,
use_fast_accum: bool = False,
*,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> TensorLike:
nva = getnv(a, fd, lc_to_nv_map)
nvb = getnv(b, fd, lc_to_nv_map)
nv_scalea = getnv(scale_a, fd, lc_to_nv_map)
nv_scaleb = getnv(scale_b, fd, lc_to_nv_map)
nv_float32 = lcdtype_to_nvdtype(dtypes.float32)

out = fd.ops.matmul(
fd.ops.mul(fd.ops.cast(nva, nv_float32), nv_scalea),
fd.ops.mul(fd.ops.cast(nvb, nv_float32), nv_scaleb),
)
if bias is not None:
out = fd.ops.add(out, getnv(bias, fd, lc_to_nv_map))

dtype = dtypes.to_dtype(out_dtype) if out_dtype is not None else a.dtype
nv_out_dtype = lcdtype_to_nvdtype(dtype)
out = fd.ops.cast(out, nv_out_dtype)

return out


nv_decomposed_scaled_mm = ex.register_operator(
"nv_decomposed_scaled_mm",
meta=_decomposed_scaled_mm_meta,
fn=_decomposed_scaled_mm_impl,
)
register_supported(nv_decomposed_scaled_mm.id, _decomposed_scaled_mm_impl, None)


def _scaled_mm_check(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypes.dtype | None = None,
use_fast_accum: bool = False,
) -> bool:
if scale_result is not None or use_fast_accum:
return False
return True


def _scaled_mm_impl(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypes.dtype | None = None,
use_fast_accum: bool = False,
) -> bool:
return nv_decomposed_scaled_mm(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)


for sym_of_scaled_mm in (ltorch._scaled_mm, ltorch.core_aten_scaled_mm):
ex.register_supported(
sym_of_scaled_mm,
checker=_scaled_mm_check,
execution_transform=_scaled_mm_impl,
)
6 changes: 4 additions & 2 deletions thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias):

model = nn.Sequential(
nn.Linear(in_features, out_features, bias=bias),
nn.GELU(approximate="tanh"),
nn.Linear(out_features, out_features, bias=bias),
# nn.GELU(approximate="tanh"),
# nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=torch_dtype)
fp8_model = convert_to_float8_training(model)
x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)
Expand Down Expand Up @@ -305,6 +305,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
pytest.xfail("numerical error")
torch.testing.assert_close(actual, expected)

actual.mean().backward()

# TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
# Currently no subgraphs go to thunder.jit.
if is_thunderfx:
Expand Down
Loading
Loading