[nvfuserex] Decomposed torch._scaled_mm #1749

Draft
crcrpar wants to merge 6 commits into base: subclass_tensor-type-str
Conversation

crcrpar (Collaborator) commented on Feb 6, 2025

What does this PR do?

This PR adds a decomposed, emulated nvFuser definition for torch._scaled_mm so that we no longer need to worry about the column/row-major layout requirements on the FP8 input matrices of torch._scaled_mm, especially in the backward pass.
The backward (bottom trace) correctly uses nv_decomposed_scaled_mm, but the forward does not.
The reason is not yet clear to me.

The decomposed torch._scaled_mm consists of (1) upcasting the two FP8 matrices to FP32, (2) scaling the two matrices, and (3) a matmul of the two.
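
For reference, a minimal sketch of what such a decomposition computes (the function name and signature here are illustrative only, not the actual nv_decomposed_scaled_mm definition registered by this PR):

```python
import torch

def decomposed_scaled_mm_sketch(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.float32):
    """Hypothetical emulation of a scaled matmul via upcast + scale + matmul.

    scale_a / scale_b are the dequantization scales (reciprocals of the
    quantization scales), matching what the forward trace below passes
    as t210 / t211.
    """
    # (1) upcast the two FP8 matrices to FP32
    a = a_fp8.to(torch.float32)
    b = b_fp8.to(torch.float32)
    # (2) apply the scales
    a = a * scale_a
    b = b * scale_b
    # (3) plain matmul, then cast to the requested output dtype
    return (a @ b).to(out_dtype)
```

Because step (3) is an ordinary matmul on FP32 tensors, it has none of the row-/column-major layout requirements that torch._scaled_mm imposes on its FP8 inputs, which is what makes the backward path easier to handle.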

Forward trace:

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from torch import Tensor
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, t_0_bias, weight):
  # input: "cuda:0 f32[16, 32]"
  # t_0_bias: "cuda:0 f32[64]"
  # weight: "cuda:0 f32[64, 32]"
  [scale, t162, t183, t210, t211] = nvFusion0(input, weight)
    # t3 = prims.abs(input)  # t3: "cuda:0 f32[16, 32]"
    # amax = prims.amax(t3, (0, 1))  # amax: "cuda:0 f32[]"
    # t5 = prims.convert_element_type(amax, dtypes.float64)  # t5: "cuda:0 f64[]"
    # t147 = prims.ne(t5, t5)  # t147: "cuda:0 b8[]"
    # t148 = prims.gt(t5, 1e-12)  # t148: "cuda:0 b8[]"
    # t149 = prims.where(t148, t5, 1e-12)  # t149: "cuda:0 f64[]"
    # t10 = prims.where(t147, t5, t149)  # t10: "cuda:0 f64[]"
    # res = prims.div(448.0, t10)  # res: "cuda:0 f64[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"
    # t153 = prims.broadcast_in_dim(scale, (16, 32), ())  # t153: "cuda:0 f32[16, 32]"
    # t154 = prims.mul(input, t153)  # t154: "cuda:0 f32[16, 32]"
    # t155 = prims.ne(t154, t154)  # t155: "cuda:0 b8[16, 32]"
    # t156 = prims.gt(t154, -448.0)  # t156: "cuda:0 b8[16, 32]"
    # t157 = prims.where(t156, t154, -448.0)  # t157: "cuda:0 f32[16, 32]"
    # t158 = prims.where(t155, t154, t157)  # t158: "cuda:0 f32[16, 32]"
    # t159 = prims.ne(t158, t158)  # t159: "cuda:0 b8[16, 32]"
    # t160 = prims.lt(t158, 448.0)  # t160: "cuda:0 b8[16, 32]"
    # t161 = prims.where(t160, t158, 448.0)  # t161: "cuda:0 f32[16, 32]"
    # t162 = prims.where(t159, t158, t161)  # t162: "cuda:0 f32[16, 32]"
    # t49 = prims.abs(weight)  # t49: "cuda:0 f32[64, 32]"
    # t50 = prims.amax(t49, (0, 1))  # t50: "cuda:0 f32[]"
    # t51 = prims.convert_element_type(t50, dtypes.float64)  # t51: "cuda:0 f64[]"
    # t168 = prims.ne(t51, t51)  # t168: "cuda:0 b8[]"
    # t169 = prims.gt(t51, 1e-12)  # t169: "cuda:0 b8[]"
    # t170 = prims.where(t169, t51, 1e-12)  # t170: "cuda:0 f64[]"
    # t55 = prims.where(t168, t51, t170)  # t55: "cuda:0 f64[]"
    # t56 = prims.div(448.0, t55)  # t56: "cuda:0 f64[]"
    # weight_scale = prims.convert_element_type(t56, dtypes.float32)  # weight_scale: "cuda:0 f32[]"
    # t174 = prims.broadcast_in_dim(weight_scale, (64, 32), ())  # t174: "cuda:0 f32[64, 32]"
    # t175 = prims.mul(weight, t174)  # t175: "cuda:0 f32[64, 32]"
    # t176 = prims.ne(t175, t175)  # t176: "cuda:0 b8[64, 32]"
    # t177 = prims.gt(t175, -448.0)  # t177: "cuda:0 b8[64, 32]"
    # t178 = prims.where(t177, t175, -448.0)  # t178: "cuda:0 f32[64, 32]"
    # t179 = prims.where(t176, t175, t178)  # t179: "cuda:0 f32[64, 32]"
    # t180 = prims.ne(t179, t179)  # t180: "cuda:0 b8[64, 32]"
    # t181 = prims.lt(t179, 448.0)  # t181: "cuda:0 b8[64, 32]"
    # t182 = prims.where(t181, t179, 448.0)  # t182: "cuda:0 f32[64, 32]"
    # t183 = prims.where(t180, t179, t182)  # t183: "cuda:0 f32[64, 32]"
    # t210 = prims.reciprocal(scale)  # t210: "cuda:0 f32[]"
    # t211 = prims.reciprocal(weight_scale)  # t211: "cuda:0 f32[]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1965:                   self.requires_grad,
  t163 = Tensor.to(t162, copy=False, dtype=torch.float8_e4m3fn)  # t163: "cuda:0 f8_e4m3fn[16, 32]"
    # t163 = ltorch.to(t162, None, None, device=None, dtype=torch.float8_e4m3fn, copy=False, memory_format=None)  # t163: "cuda:0 f8_e4m3fn[16, 32]"
      # t163 = prims.convert_element_type(t162, dtypes.float8_e4m3fn)  # t163: "cuda:0 f8_e4m3fn[16, 32]"
  del t162

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1965:                   self.requires_grad,
  t184 = Tensor.to(t183, copy=False, dtype=torch.float8_e4m3fn)  # t184: "cuda:0 f8_e4m3fn[64, 32]"
    # t184 = ltorch.to(t183, None, None, device=None, dtype=torch.float8_e4m3fn, copy=False, memory_format=None)  # t184: "cuda:0 f8_e4m3fn[64, 32]"
      # t184 = prims.convert_element_type(t183, dtypes.float8_e4m3fn)  # t184: "cuda:0 f8_e4m3fn[64, 32]"
  del t183

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1965:                   self.requires_grad,
  input_fp8 = Float8Tensor(t163, scale, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # input_fp8: "Float8Tensor[cuda:0 f32[16, 32]] (tensors: _data: cuda:0 f8_e4m3fn[16, 32], _scale: cuda:0 f32[], constants: _orig_dtype: thunder.dtypes.float32, _linear_mm_config: LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _gemm_input_role: GemmInputRole.INPUT, _axiswise_dim: None)"
  del scale
  t204 = torch.reshape(t163, [-1, 32])  # t204: "cuda:0 f8_e4m3fn[16, 32]"
    # t204 = ltorch.reshape(t163, [-1, 32])  # t204: "cuda:0 f8_e4m3fn[16, 32]"
      # t204 = prims.reshape(t163, (16, 32))  # t204: "cuda:0 f8_e4m3fn[16, 32]"
  del t163
  t197 = torch.permute(t184, (1, 0))  # t197: "cuda:0 f8_e4m3fn[32, 64]"
    # t197 = ltorch.permute(t184, (1, 0))  # t197: "cuda:0 f8_e4m3fn[32, 64]"
      # t197 = prims.transpose(t184, (1, 0))  # t197: "cuda:0 f8_e4m3fn[32, 64]"
  del t184
  t207 = torch.transpose(t197, 0, 1)  # t207: "cuda:0 f8_e4m3fn[64, 32]"
    # t207 = ltorch.transpose(t197, 0, 1)  # t207: "cuda:0 f8_e4m3fn[64, 32]"
      # t207 = prims.transpose(t197, (1, 0))  # t207: "cuda:0 f8_e4m3fn[64, 32]"
  del t197
  t208 = torch.clone(t207)  # t208: "cuda:0 f8_e4m3fn[64, 32]"
    # t208 = ltorch.clone(t207, memory_format=_torch_memory_format_7)  # t208: "cuda:0 f8_e4m3fn[64, 32]"
      # t208 = prims.clone(t207)  # t208: "cuda:0 f8_e4m3fn[64, 32]"
  del t207
  t209 = torch.transpose(t208, 0, 1)  # t209: "cuda:0 f8_e4m3fn[32, 64]"
    # t209 = ltorch.transpose(t208, 0, 1)  # t209: "cuda:0 f8_e4m3fn[32, 64]"
      # t209 = prims.transpose(t208, (1, 0))  # t209: "cuda:0 f8_e4m3fn[32, 64]"
  del t208
  t222 = torch.transpose(t209, 0, 1)  # t222: "cuda:0 f8_e4m3fn[64, 32]"
    # t222 = ltorch.transpose(t209, 0, 1)  # t222: "cuda:0 f8_e4m3fn[64, 32]"
      # t222 = prims.transpose(t209, (1, 0))  # t222: "cuda:0 f8_e4m3fn[64, 32]"
  del t209
  t223 = Tensor.contiguous(t222, memory_format=_torch_memory_format_6)  # t223: "cuda:0 f8_e4m3fn[64, 32]"
    # t223 = ltorch.contiguous(t222, memory_format=_torch_memory_format_6)  # t223: "cuda:0 f8_e4m3fn[64, 32]"
      # t223 = prims.stride_order(t222, (1, 0))  # t223: "cuda:0 f8_e4m3fn[64, 32]"
  del t222
  t224 = torch.transpose(t223, 0, 1)  # t224: "cuda:0 f8_e4m3fn[32, 64]"
    # t224 = ltorch.transpose(t223, 0, 1)  # t224: "cuda:0 f8_e4m3fn[32, 64]"
      # t224 = prims.transpose(t223, (1, 0))  # t224: "cuda:0 f8_e4m3fn[32, 64]"
  del t223
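  # NOTE: the forward still calls torch._scaled_mm directly here; it does not pick up the nvFuser decomposition.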
  t212 = torch._scaled_mm(t204, t224, t210, t211, None, None, torch.float32, True)  # t212: "cuda:0 f32[16, 64]"
  del t204, t224, t210, t211

  # /home/mkozuki/.pyenv/versions/3.10.13/envs/torchdev-3.10/lib/python3.10/site-packages/torchao/float8/float8_linear.py:106:          return grad_input, grad_weight.t()
  t103 = shallow_copy(t212)  # t103: "cuda:0 f32[16, 64]"
  del t212
  [t143] = nvFusion1(t_0_bias, t103)
    # t190 = prims.broadcast_in_dim(t_0_bias, (16, 64), (1,))  # t190: "cuda:0 f32[16, 64]"
    # t143 = prims.add(t103, t190)  # t143: "cuda:0 f32[16, 64]"
  del t103
  return {'output': (t143,), 'flat_args': [input, t_0_bias, weight], 'flat_output': (t143,)}, ((input_fp8,), ())
```

Backward trace:

```python
# Constructed by Delete Last Used (took 1 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from torch import Tensor
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  input_fp8, = C0
  clear_mutable_collection(C0)
  del C0
  [t29, t59, t275, t263] = nvFusion0(t0)
    # t29 = prims.sum(t0, (0,))  # t29: "cuda:0 f32[64]"
    # t30 = prims.abs(t0)  # t30: "cuda:0 f32[16, 64]"
    # t43 = prims.amax(t30, (0, 1))  # t43: "cuda:0 f32[]"
    # t44 = prims.convert_element_type(t43, dtypes.float64)  # t44: "cuda:0 f64[]"
    # t45 = prims.ne(t44, t44)  # t45: "cuda:0 b8[]"
    # t46 = prims.gt(t44, 1e-12)  # t46: "cuda:0 b8[]"
    # t47 = prims.where(t46, t44, 1e-12)  # t47: "cuda:0 f64[]"
    # t48 = prims.where(t45, t44, t47)  # t48: "cuda:0 f64[]"
    # t58 = prims.div(57344.0, t48)  # t58: "cuda:0 f64[]"
    # t59 = prims.convert_element_type(t58, dtypes.float32)  # t59: "cuda:0 f32[]"
    # t266 = prims.broadcast_in_dim(t59, (16, 64), ())  # t266: "cuda:0 f32[16, 64]"
    # t267 = prims.mul(t0, t266)  # t267: "cuda:0 f32[16, 64]"
    # t268 = prims.ne(t267, t267)  # t268: "cuda:0 b8[16, 64]"
    # t269 = prims.gt(t267, -57344.0)  # t269: "cuda:0 b8[16, 64]"
    # t270 = prims.where(t269, t267, -57344.0)  # t270: "cuda:0 f32[16, 64]"
    # t271 = prims.where(t268, t267, t270)  # t271: "cuda:0 f32[16, 64]"
    # t272 = prims.ne(t271, t271)  # t272: "cuda:0 b8[16, 64]"
    # t273 = prims.lt(t271, 57344.0)  # t273: "cuda:0 b8[16, 64]"
    # t274 = prims.where(t273, t271, 57344.0)  # t274: "cuda:0 f32[16, 64]"
    # t275 = prims.where(t272, t271, t274)  # t275: "cuda:0 f32[16, 64]"
    # t263 = prims.reciprocal(t59)  # t263: "cuda:0 f32[]"
  del t0
  (t41, t12) = flatten_tensor_subclass(input_fp8)
  del input_fp8
  t276 = Tensor.to(t275, copy=False, dtype=torch.float8_e5m2)  # t276: "cuda:0 f8_e5m2[16, 64]"
    # t276 = ltorch.to(t275, None, None, device=None, dtype=torch.float8_e5m2, copy=False, memory_format=None)  # t276: "cuda:0 f8_e5m2[16, 64]"
      # t276 = prims.convert_element_type(t275, dtypes.float8_e5m2)  # t276: "cuda:0 f8_e5m2[16, 64]"
  del t275
  t252 = torch.reshape(t41, [-1, 32])  # t252: "cuda:0 f8_e4m3fn[16, 32]"
    # t252 = ltorch.reshape(t41, [-1, 32])  # t252: "cuda:0 f8_e4m3fn[16, 32]"
      # t252 = prims.reshape(t41, (16, 32))  # t252: "cuda:0 f8_e4m3fn[16, 32]"
  del t41
  t192 = Float8Tensor(t276, t59, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # t192: "Float8Tensor[cuda:0 f32[16, 64]] (tensors: _data: cuda:0 f8_e5m2[16, 64], _scale: cuda:0 f32[], constants: _orig_dtype: thunder.dtypes.float32, _linear_mm_config: LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _gemm_input_role: GemmInputRole.GRAD_OUTPUT, _axiswise_dim: None)"
  del t276, t59
  t260 = torch.transpose(t252, 0, 1)  # t260: "cuda:0 f8_e4m3fn[32, 16]"
    # t260 = ltorch.transpose(t252, 0, 1)  # t260: "cuda:0 f8_e4m3fn[32, 16]"
      # t260 = prims.transpose(t252, (1, 0))  # t260: "cuda:0 f8_e4m3fn[32, 16]"
  del t252
  (t141, _) = flatten_tensor_subclass(t192)
  del t192
  t261 = torch.clone(t260)  # t261: "cuda:0 f8_e4m3fn[32, 16]"
    # t261 = ltorch.clone(t260, memory_format=_torch_memory_format_5)  # t261: "cuda:0 f8_e4m3fn[32, 16]"
      # t261 = prims.clone(t260)  # t261: "cuda:0 f8_e4m3fn[32, 16]"
  del t260
  t233 = torch.reshape(t141, [-1, 64])  # t233: "cuda:0 f8_e5m2[16, 64]"
    # t233 = ltorch.reshape(t141, [-1, 64])  # t233: "cuda:0 f8_e5m2[16, 64]"
      # t233 = prims.reshape(t141, (16, 64))  # t233: "cuda:0 f8_e5m2[16, 64]"
  del t141
  t262 = torch.transpose(t261, 0, 1)  # t262: "cuda:0 f8_e4m3fn[16, 32]"
    # t262 = ltorch.transpose(t261, 0, 1)  # t262: "cuda:0 f8_e4m3fn[16, 32]"
      # t262 = prims.transpose(t261, (1, 0))  # t262: "cuda:0 f8_e4m3fn[16, 32]"
  del t261
  t257 = torch.permute(t233, (1, 0))  # t257: "cuda:0 f8_e5m2[64, 16]"
    # t257 = ltorch.permute(t233, (1, 0))  # t257: "cuda:0 f8_e5m2[64, 16]"
      # t257 = prims.transpose(t233, (1, 0))  # t257: "cuda:0 f8_e5m2[64, 16]"
  del t233
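  # NOTE: here the backward correctly routes the scaled matmul through the nvFuser decomposition (nv_decomposed_scaled_mm inside nvFusion1 below).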
  [t200] = nvFusion1(t12, t257, t262, t263)
    # t264 = prims.reciprocal(t12)  # t264: "cuda:0 f32[]"
    # t265 = nv_decomposed_scaled_mm(t257, t262, t263, t264, None, None, torch.float32, False)  # t265: "cuda:0 f32[64, 32]"
    # t199 = prims.transpose(t265, (1, 0))  # t199: "cuda:0 f32[32, 64]"
    # t200 = prims.transpose(t199, (1, 0))  # t200: "cuda:0 f32[64, 32]"
  del t12, t257, t262, t263
  return (None, t29, t200)
```

Commits:
- Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
- Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
- Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
- currently backward is failing because a key is missing in `lc_to_nv_map`
  Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
- for an easier comparison between the unrolled trace and the input trace
  Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
- backward now uses nvfuser decomposition but forward mysteriously does not
  Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>