From fb59e9ec5345289a155e8b6d7fb2779925070869 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 4 Feb 2025 14:36:11 +0900 Subject: [PATCH 1/6] cyclomatic complexity from 33 to 19 Signed-off-by: Masaki Kozuki --- thunder/tests/test_tensor_subclass.py | 6 +- thunder/transforms/tensor_wrapper_subclass.py | 119 +++++++++++------- 2 files changed, 77 insertions(+), 48 deletions(-) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index 9abd0c11ad..bc728ff1d4 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -268,8 +268,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias): model = nn.Sequential( nn.Linear(in_features, out_features, bias=bias), - nn.GELU(approximate="tanh"), - nn.Linear(out_features, out_features, bias=bias), + # nn.GELU(approximate="tanh"), + # nn.Linear(out_features, out_features, bias=bias), ).to(device=device, dtype=torch_dtype) fp8_model = convert_to_float8_training(model) x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype) @@ -305,6 +305,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias): pytest.xfail("numerical error") torch.testing.assert_close(actual, expected) + actual.mean().backward() + # TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`. # Currently no subgraphs go to thunder.jit. if is_thunderfx: diff --git a/thunder/transforms/tensor_wrapper_subclass.py b/thunder/transforms/tensor_wrapper_subclass.py index 1379a4bd4c..0c7da1156b 100644 --- a/thunder/transforms/tensor_wrapper_subclass.py +++ b/thunder/transforms/tensor_wrapper_subclass.py @@ -225,6 +225,14 @@ def torch_interpreted_func(*args, **kwargs): return torch_trace +def create_ctor_of_tensor_subclass(unflatten_method, tensor_names): + def ctor(tensors, metadata): + inner_tensors = dict(zip(tensor_names, tensors)) + return unflatten_method(inner_tensors, metadata, -1, -1) + + return ctor + + @dataclass class DesugarTensorSubclass: computation_trace: TraceCtx @@ -270,14 +278,7 @@ def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]: def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]: return p._non_tensor_attr_names - def translate_fx_graph_into_bsym( - self, - bsym: BoundSymbol, - fx_graph: GraphModule, - ) -> BoundSymbol | tuple[BoundSymbol, ...]: - import thunder.torch as ltorch - from thunder.torch import _torch_to_thunder_function_map - + def _unwrap_bsym_args(self, bsym: BoundSymbol) -> tuple[dict[int, ProxyInterface], list[BoundSymbol]]: unwrapped_bsym_args: dict[int, ProxyInterface] = {} list_of_flattening_bsyms: list[BoundSymbol] = [] for a in bsym.flat_args: @@ -311,36 +312,18 @@ def translate_fx_graph_into_bsym( with tracectx(self.computation_trace): a = proxy(a) unwrapped_bsym_args[len(unwrapped_bsym_args)] = a + return unwrapped_bsym_args, list_of_flattening_bsyms - node: Node - list_of_placeholder_node: list[Node] = [] - list_of_function_call_node: list[Node] = [] - node_of_output: Node - for node in fx_graph.graph.nodes: - if node.op == PLACEHOLDER: - list_of_placeholder_node.append(node) - if node.op == CALL_FUNCTION: - list_of_function_call_node.append(node) - if node.op == OUTPUT: - node_of_output = node - args = [n.target for n in list_of_placeholder_node] - arg_name_to_index = {a: i for i, a in enumerate(args)} - ltorch_ops_for_node_of_ops = [] - for node in list_of_function_call_node: - op: OpOverload = node.target - if op not in _torch_to_thunder_function_map: - msg = ( - 
f"`thunder.torch` does not have corresponding op for {op}. " - "Think about adding it to thunder/torch/default_torch_ops.py" - f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}" - f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}" - ) - raise RuntimeError(msg) - ltorch_ops_for_node_of_ops.append(_torch_to_thunder_function_map[op]) - + def _evaluate_ltorch_op( + self, + list_of_function_call_node: list[Node], + ltorch_ops_for_node_of_ops: list[Symbol], + unwrapped_bsym_args: dict[int, ProxyInterface], + arg_name_to_index: dict[str, int], + *, + _cur_bsym_for_error_msg: BoundSymbol, + ) -> list[BoundSymbol]: bsyms: list[BoundSymbol] = [] - if list_of_flattening_bsyms: - bsyms.extend(list_of_flattening_bsyms) fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {} for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops): args: list[Node] = node.args @@ -369,7 +352,7 @@ def translate_fx_graph_into_bsym( msg = ( f"Failing to map `torch.{node}` to `thunder.torch` op of " f"{ltorch_op} with args of {arg_proxies}\n" - f"BoundSymbol in question is\n```python\n{bsym}\n```\n" + f"BoundSymbol in question is\n```python\n{_cur_bsym_for_error_msg}\n```\n" f"Corresponding torch.fx Graph is\n```python\n{fx_graph.print_readable(print_output=False)}\n```\n" f"Original error is {e}" ) @@ -377,7 +360,55 @@ def translate_fx_graph_into_bsym( else: fxnode_output_name_to_tensor_proxy[str(node)] = out bsyms.extend(self.computation_trace.pop_scope()) - if len(bsyms) == 0: + return bsyms, fxnode_output_name_to_tensor_proxy + + def translate_fx_graph_into_bsym( + self, + bsym: BoundSymbol, + fx_graph: GraphModule, + ) -> BoundSymbol | tuple[BoundSymbol, ...]: + from thunder.torch import _torch_to_thunder_function_map + + unwrapped_bsym_args, list_of_flattening_bsyms = self._unwrap_bsym_args(bsym) + + node: Node + list_of_placeholder_node: list[Node] = [] + list_of_function_call_node: list[Node] = [] + node_of_output: Node + for node in fx_graph.graph.nodes: + if node.op == PLACEHOLDER: + list_of_placeholder_node.append(node) + if node.op == CALL_FUNCTION: + list_of_function_call_node.append(node) + if node.op == OUTPUT: + node_of_output = node + args = [n.target for n in list_of_placeholder_node] + arg_name_to_index = {a: i for i, a in enumerate(args)} + ltorch_ops_for_node_of_ops = [] + for node in list_of_function_call_node: + op: OpOverload = node.target + if op not in _torch_to_thunder_function_map: + msg = ( + f"`thunder.torch` does not have corresponding op for {op}. 
" + "Think about adding it to thunder/torch/default_torch_ops.py" + f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}" + f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}" + ) + raise RuntimeError(msg) + ltorch_ops_for_node_of_ops.append(_torch_to_thunder_function_map[op]) + + bsyms: list[BoundSymbol] = [] + if list_of_flattening_bsyms: + bsyms.extend(list_of_flattening_bsyms) + bsyms_of_torch_ops, fxnode_output_name_to_tensor_proxy = self._evaluate_ltorch_op( + list_of_function_call_node, + ltorch_ops_for_node_of_ops, + unwrapped_bsym_args, + arg_name_to_index, + _cur_bsym_for_error_msg=bsym, + ) + bsyms.extend(bsyms_of_torch_ops) + if not bsyms: return [bsym] orig_output = bsym.flat_outs[0] @@ -401,7 +432,7 @@ def translate_fx_graph_into_bsym( if isinstance(value, immutable_dict): new_non_tensor_values.append(dict(value)) else: - new_non_tensor_values.append(list(v)) + new_non_tensor_values.append(list(value)) else: new_non_tensor_values.append(value) utils.check( @@ -444,12 +475,6 @@ def convert_trace_to_fx_graph_and_get_fake_result( self, trace: TraceCtx, ) -> tuple[GraphModule, tuple[OutputWrapperForFxTracing, ...], tuple[torch.Tensor, ...], PyTreeSpec]: - def create_ctor(unflatten_method, tensor_names): - def ctor(tensors, metadata): - inner_tensors = dict(zip(tensor_names, tensors)) - return unflatten_method(inner_tensors, metadata, -1, -1) - - return ctor args = tree_map( lambda t: maybe_materialize_tensor( @@ -467,7 +492,9 @@ def ctor(tensors, metadata): desugared_args.extend([getattr(a, name) for name in attrs]) desugared_args.append(metadta) end_idx = len(desugared_args) - arg_idx_to_sugar[start_idx] = end_idx, create_ctor(type(a).__tensor_unflatten__, attrs) + arg_idx_to_sugar[start_idx] = end_idx, create_ctor_of_tensor_subclass( + type(a).__tensor_unflatten__, attrs + ) else: desugared_args.append(a) From 6ac70f45e86b31c44dda5094cd5b1f09c2f8a6f4 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 4 Feb 2025 14:46:35 +0900 Subject: [PATCH 2/6] complexity relaxed Signed-off-by: Masaki Kozuki --- thunder/transforms/tensor_wrapper_subclass.py | 102 ++++++++++-------- 1 file changed, 59 insertions(+), 43 deletions(-) diff --git a/thunder/transforms/tensor_wrapper_subclass.py b/thunder/transforms/tensor_wrapper_subclass.py index 0c7da1156b..b51e2ac9fe 100644 --- a/thunder/transforms/tensor_wrapper_subclass.py +++ b/thunder/transforms/tensor_wrapper_subclass.py @@ -362,6 +362,56 @@ def _evaluate_ltorch_op( bsyms.extend(self.computation_trace.pop_scope()) return bsyms, fxnode_output_name_to_tensor_proxy + def _postprocess_subclass_from_torch_op( + self, + orig_output: SubclassTensorProxy, + node_of_output: Node, + arg_name_to_index: dict[str, int], + unwrapped_bsym_args: dict[int, ProxyInterface], + fxnode_output_name_to_tensor_proxy: dict[str, OpOverload], + ) -> tuple[BoundSymbol, SubclassTensorProxy]: + # note(crcrpar): args[0] would be list of tensors, and args[1] could be list of non-tensors. 
+ args: list[Node] = node_of_output.args[0] + new_tensor_proxies = [] + new_non_tensor_values = [] + for a in args: + value = a + if isinstance(a, Node): + if isinstance(a.target, str): + value = unwrapped_bsym_args[arg_name_to_index[a.target]] + else: + value = fxnode_output_name_to_tensor_proxy[str(a)] + if isinstance(value, TensorProxy): + new_tensor_proxies.append(value) + elif isinstance(value, (immutable_dict, immutable_list)): + if isinstance(value, immutable_dict): + new_non_tensor_values.append(dict(value)) + else: + new_non_tensor_values.append(list(value)) + else: + new_non_tensor_values.append(value) + utils.check( + len(orig_output._tensors) == len(new_tensor_proxies), + lambda: ( + f"The number of new tensor proxies for {orig_output=} does not match: " + f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}" + ), + ) + with tracectx(self.computation_trace): + new_subclass = orig_output.replace() + new_subclass._tensors = new_tensor_proxies + for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies): + setattr(new_subclass, name, value) + return ( + prims.unflatten_tensor_subclass.bind( + new_subclass._subclass_type, + dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)), + dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)), + output=new_subclass, + ), + new_subclass, + ) + def translate_fx_graph_into_bsym( self, bsym: BoundSymbol, @@ -415,56 +465,22 @@ def translate_fx_graph_into_bsym( if is_subclass_ctor_bsym := bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR: utils.check_type(orig_output, SubclassTensorProxy) if isinstance(orig_output, SubclassTensorProxy): - # note(crcrpar): args[0] would be list of tensors, and args[1] could be list of non-tensors. - args: list[Node] = node_of_output.args[0] - new_tensor_proxies = [] - new_non_tensor_values = [] - for a in args: - value = a - if isinstance(a, Node): - if isinstance(a.target, str): - value = unwrapped_bsym_args[arg_name_to_index[a.target]] - else: - value = fxnode_output_name_to_tensor_proxy[str(a)] - if isinstance(value, TensorProxy): - new_tensor_proxies.append(value) - elif isinstance(value, (immutable_dict, immutable_list)): - if isinstance(value, immutable_dict): - new_non_tensor_values.append(dict(value)) - else: - new_non_tensor_values.append(list(value)) - else: - new_non_tensor_values.append(value) - utils.check( - len(orig_output._tensors) == len(new_tensor_proxies), - lambda: ( - f"The number of new tensor proxies for {orig_output=} does not match: " - f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}" - ), - ) - with tracectx(self.computation_trace): - new_subclass = orig_output.replace() - new_subclass._tensors = new_tensor_proxies - for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies): - setattr(new_subclass, name, value) - bsyms.append( - prims.unflatten_tensor_subclass.bind( - new_subclass._subclass_type, - dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)), - dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)), - output=new_subclass, - ) + unflatten_bsym, new_subclass_proxy = self._postprocess_subclass_from_torch_op( + orig_output, + node_of_output, + arg_name_to_index, + unwrapped_bsym_args, + fxnode_output_name_to_tensor_proxy, ) - - self.swap_map[variableify(orig_output)] = new_subclass - self.subclass_proxy_to_flatten.add(variableify(new_subclass)) + bsyms.append(unflatten_bsym) + self.swap_map[variableify(orig_output)] = new_subclass_proxy + 
self.subclass_proxy_to_flatten.add(variableify(new_subclass_proxy)) else: non_none_args = [n for n in node_of_output.args[0] if n is not None] utils.check(len(non_none_args) == 1, lambda: f"{node_of_output.args = }") new_out_node = non_none_args[0] self.swap_map[variableify(orig_output)] = fxnode_output_name_to_tensor_proxy[str(new_out_node)] - args = ", ".join([t.name if isinstance(t, ProxyInterface) else f"{t}" for t in bsym.flat_args]) header = f"{bsym.sym.id}({args})" for i, sbsym in enumerate(bsyms, 1): From f09c32fe1c4fbad60794ac303fc1dca19b6ef24d Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 5 Feb 2025 18:26:49 +0900 Subject: [PATCH 3/6] supply missing fx graph Signed-off-by: Masaki Kozuki --- thunder/transforms/tensor_wrapper_subclass.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/transforms/tensor_wrapper_subclass.py b/thunder/transforms/tensor_wrapper_subclass.py index b51e2ac9fe..7079c89dfd 100644 --- a/thunder/transforms/tensor_wrapper_subclass.py +++ b/thunder/transforms/tensor_wrapper_subclass.py @@ -322,6 +322,7 @@ def _evaluate_ltorch_op( arg_name_to_index: dict[str, int], *, _cur_bsym_for_error_msg: BoundSymbol, + _cur_fxgraph_for_error_msg: GraphModule, ) -> list[BoundSymbol]: bsyms: list[BoundSymbol] = [] fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {} @@ -353,7 +354,7 @@ def _evaluate_ltorch_op( f"Failing to map `torch.{node}` to `thunder.torch` op of " f"{ltorch_op} with args of {arg_proxies}\n" f"BoundSymbol in question is\n```python\n{_cur_bsym_for_error_msg}\n```\n" - f"Corresponding torch.fx Graph is\n```python\n{fx_graph.print_readable(print_output=False)}\n```\n" + f"Corresponding torch.fx Graph is\n```python\n{_cur_fxgraph_for_error_msg.print_readable(print_output=False)}\n```\n" f"Original error is {e}" ) raise type(e)(msg) @@ -456,6 +457,7 @@ def translate_fx_graph_into_bsym( unwrapped_bsym_args, arg_name_to_index, _cur_bsym_for_error_msg=bsym, + _cur_fxgraph_for_error_msg=fx_graph, ) bsyms.extend(bsyms_of_torch_ops) if not bsyms: From 4724b1e2d4523a1c8a951308479ea644a59d9887 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 5 Feb 2025 13:07:52 +0000 Subject: [PATCH 4/6] add nvfuser decomposed scaled mm currently backward is failing because a key is missing in `lc_to_nv_map` Signed-off-by: Masaki Kozuki --- thunder/executors/nvfuserex_impl.py | 86 +++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 9abab32c79..896679f230 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -2593,3 +2593,89 @@ def scaled_dot_product_flash_attention_grad( execution_transform=scaled_dot_product_flash_attention, grad_transform=scaled_dot_product_flash_attention_grad, ) + + +def _decomposed_scaled_mm_meta( + a: TensorLike, + b: TensorLike, + scale_a: TensorLike, + scale_b: TensorLike, + bias: TensorLike | None = None, + scale_result: TensorLike | None = None, + out_dtype: dtypes.dtype | None = None, + use_fast_accum: bool = False, +) -> TensorLike: + return TensorProxy(shape=(a.shape[0], b.shape[1]), device=a.device, dtype=out_dtype or a.dtype) + + +def _decomposed_scaled_mm_impl( + a: TensorLike, + b: TensorLike, + scale_a: TensorLike, + scale_b: TensorLike, + bias: TensorLike | None = None, + scale_result: TensorLike | None = None, + out_dtype: dtypes.dtype | None = None, + use_fast_accum: bool = False, + *, + fd: FusionDefinition, + lc_to_nv_map: dict, +) -> TensorLike: + 
nva = getnv(a, fd, lc_to_nv_map)
+    nvb = getnv(b, fd, lc_to_nv_map)
+    nv_scalea = getnv(scale_a, fd, lc_to_nv_map)
+    nv_scaleb = getnv(scale_b, fd, lc_to_nv_map)
+    nv_float32 = lcdtype_to_nvdtype(dtypes.float32)
+
+    out = fd.ops.matmul(
+        fd.ops.mul(fd.ops.cast(nva, nv_float32), nv_scalea),
+        fd.ops.mul(fd.ops.cast(nvb, nv_float32), nv_scaleb),
+    )
+    if bias is not None:
+        out = fd.ops.add(out, getnv(bias, fd, lc_to_nv_map))
+
+    if out_dtype is not None:
+        nv_out_dtype = lcdtype_to_nvdtype(out_dtype)
+        out = fd.ops.cast(out, nv_out_dtype)
+
+    return out
+
+
+nv_decomposed_scaled_mm = ex.register_operator(
+    "nv_decomposed_scaled_mm",
+    meta=_decomposed_scaled_mm_meta,
+    fn=_decomposed_scaled_mm_impl,
+)
+register_supported(nv_decomposed_scaled_mm.id, _decomposed_scaled_mm_impl, None)
+
+
+def _scaled_mm_check(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypes.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> bool:
+    if scale_result is not None or use_fast_accum:
+        return False
+    return True
+
+
+def _scaled_mm_impl(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypes.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> TensorLike:
+    return nv_decomposed_scaled_mm(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)
+
+
+register_supported(ltorch._scaled_mm, _scaled_mm_impl, _scaled_mm_check)
+register_supported(ltorch.core_aten_scaled_mm, _scaled_mm_impl, _scaled_mm_check)

From 33b45dab16f6180f3656fe2aea08c248c3da36d7 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Thu, 6 Feb 2025 20:48:41 +0900
Subject: [PATCH 5/6] use original signature for an easier comparison between
 the unrolled trace and the input trace

Signed-off-by: Masaki Kozuki
---
 thunder/transforms/tensor_wrapper_subclass.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/thunder/transforms/tensor_wrapper_subclass.py b/thunder/transforms/tensor_wrapper_subclass.py
index 7079c89dfd..7916beb010 100644
--- a/thunder/transforms/tensor_wrapper_subclass.py
+++ b/thunder/transforms/tensor_wrapper_subclass.py
@@ -417,6 +417,7 @@ def translate_fx_graph_into_bsym(
         self,
         bsym: BoundSymbol,
         fx_graph: GraphModule,
+        header: str,
     ) -> BoundSymbol | tuple[BoundSymbol, ...]:
         from thunder.torch import _torch_to_thunder_function_map
 
@@ -483,8 +484,6 @@ def translate_fx_graph_into_bsym(
             utils.check(len(non_none_args) == 1, lambda: f"{node_of_output.args = }")
             new_out_node = non_none_args[0]
             self.swap_map[variableify(orig_output)] = fxnode_output_name_to_tensor_proxy[str(new_out_node)]
-        args = ", ".join([t.name if isinstance(t, ProxyInterface) else f"{t}" for t in bsym.flat_args])
-        header = f"{bsym.sym.id}({args})"
         for i, sbsym in enumerate(bsyms, 1):
             sbsym.header = f"[{i}/{len(bsyms)}] unrolled `__torch_dispatch__` of `{header}`"
         return bsyms
@@ -666,7 +665,9 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
 
         bsym_with_modified_output = updated_bsym.from_bsym_swap_proxies(self.swap_map)
         self.bsym_to_new_outputs[bsym_with_modified_output] = bsym_with_modified_output
-        return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)
+        args = ", ".join([t.name if isinstance(t, ProxyInterface) else f"{t}" for t in bsym.flat_args])
+        header = f"{bsym.sym.id}({args})"
+        return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx, header=header)
 
 
 def tensor_subclass_dce(trace: TraceCtx) -> TraceCtx:

From 13ae807a4c329bedc1de420e95bf661c0b5789fc Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Thu, 6 Feb 2025 22:24:59 +0900
Subject: [PATCH 6/6] try `ex.register_supported`

Backward now uses the nvFuser decomposition but forward mysteriously does not.

Signed-off-by: Masaki Kozuki
---
 thunder/executors/nvfuserex_impl.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 896679f230..19068c71e9 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2605,7 +2605,8 @@ def _decomposed_scaled_mm_meta(
     out_dtype: dtypes.dtype | None = None,
     use_fast_accum: bool = False,
 ) -> TensorLike:
-    return TensorProxy(shape=(a.shape[0], b.shape[1]), device=a.device, dtype=out_dtype or a.dtype)
+    dtype = dtypes.to_dtype(out_dtype) if out_dtype is not None else a.dtype
+    return TensorProxy(like=a, shape=(a.shape[0], b.shape[1]), device=a.device, dtype=dtype)
 
 
 def _decomposed_scaled_mm_impl(
@@ -2634,9 +2635,9 @@ def _decomposed_scaled_mm_impl(
     if bias is not None:
         out = fd.ops.add(out, getnv(bias, fd, lc_to_nv_map))
 
-    if out_dtype is not None:
-        nv_out_dtype = lcdtype_to_nvdtype(out_dtype)
-        out = fd.ops.cast(out, nv_out_dtype)
+    dtype = dtypes.to_dtype(out_dtype) if out_dtype is not None else a.dtype
+    nv_out_dtype = lcdtype_to_nvdtype(dtype)
+    out = fd.ops.cast(out, nv_out_dtype)
 
     return out
 
@@ -2677,5 +2678,9 @@ def _scaled_mm_impl(
     return nv_decomposed_scaled_mm(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)
 
 
-register_supported(ltorch._scaled_mm, _scaled_mm_impl, _scaled_mm_check)
-register_supported(ltorch.core_aten_scaled_mm, _scaled_mm_impl, _scaled_mm_check)
+for sym_of_scaled_mm in (ltorch._scaled_mm, ltorch.core_aten_scaled_mm):
+    ex.register_supported(
+        sym_of_scaled_mm,
+        checker=_scaled_mm_check,
+        execution_transform=_scaled_mm_impl,
+    )