diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 9abab32c79..19068c71e9 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2593,3 +2593,94 @@ def scaled_dot_product_flash_attention_grad(
     execution_transform=scaled_dot_product_flash_attention,
     grad_transform=scaled_dot_product_flash_attention_grad,
 )
+
+
+def _decomposed_scaled_mm_meta(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypes.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> TensorLike:
+    dtype = dtypes.to_dtype(out_dtype) if out_dtype is not None else a.dtype
+    return TensorProxy(like=a, shape=(a.shape[0], b.shape[1]), device=a.device, dtype=dtype)
+
+
+def _decomposed_scaled_mm_impl(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypes.dtype | None = None,
+    use_fast_accum: bool = False,
+    *,
+    fd: FusionDefinition,
+    lc_to_nv_map: dict,
+) -> TensorLike:
+    nva = getnv(a, fd, lc_to_nv_map)
+    nvb = getnv(b, fd, lc_to_nv_map)
+    nv_scalea = getnv(scale_a, fd, lc_to_nv_map)
+    nv_scaleb = getnv(scale_b, fd, lc_to_nv_map)
+    nv_float32 = lcdtype_to_nvdtype(dtypes.float32)
+
+    out = fd.ops.matmul(
+        fd.ops.mul(fd.ops.cast(nva, nv_float32), nv_scalea),
+        fd.ops.mul(fd.ops.cast(nvb, nv_float32), nv_scaleb),
+    )
+    if bias is not None:
+        out = fd.ops.add(out, getnv(bias, fd, lc_to_nv_map))
+
+    dtype = dtypes.to_dtype(out_dtype) if out_dtype is not None else a.dtype
+    nv_out_dtype = lcdtype_to_nvdtype(dtype)
+    out = fd.ops.cast(out, nv_out_dtype)
+
+    return out
+
+
+nv_decomposed_scaled_mm = ex.register_operator(
+    "nv_decomposed_scaled_mm",
+    meta=_decomposed_scaled_mm_meta,
+    fn=_decomposed_scaled_mm_impl,
+)
+register_supported(nv_decomposed_scaled_mm.id, _decomposed_scaled_mm_impl, None)
+
+
+def _scaled_mm_check(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypes.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> bool:
+    if scale_result is not None or use_fast_accum:
+        return False
+    return True
+
+
+def _scaled_mm_impl(
+    a: TensorLike,
+    b: TensorLike,
+    scale_a: TensorLike,
+    scale_b: TensorLike,
+    bias: TensorLike | None = None,
+    scale_result: TensorLike | None = None,
+    out_dtype: dtypes.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> TensorLike:
+    return nv_decomposed_scaled_mm(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)
+
+
+for sym_of_scaled_mm in (ltorch._scaled_mm, ltorch.core_aten_scaled_mm):
+    ex.register_supported(
+        sym_of_scaled_mm,
+        checker=_scaled_mm_check,
+        execution_transform=_scaled_mm_impl,
+    )
diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py
index 9abd0c11ad..bc728ff1d4 100644
--- a/thunder/tests/test_tensor_subclass.py
+++ b/thunder/tests/test_tensor_subclass.py
@@ -268,8 +268,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias):

     model = nn.Sequential(
         nn.Linear(in_features, out_features, bias=bias),
-        nn.GELU(approximate="tanh"),
-        nn.Linear(out_features, out_features, bias=bias),
+        # nn.GELU(approximate="tanh"),
+        # nn.Linear(out_features, out_features, bias=bias),
     ).to(device=device, dtype=torch_dtype)
     fp8_model = convert_to_float8_training(model)
     x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)
@@ -305,6 +305,8 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
             pytest.xfail("numerical error")
         torch.testing.assert_close(actual, expected)

+    actual.mean().backward()
+
     # TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
     # Currently no subgraphs go to thunder.jit.
     if is_thunderfx:
diff --git a/thunder/transforms/tensor_wrapper_subclass.py b/thunder/transforms/tensor_wrapper_subclass.py
index 1379a4bd4c..7916beb010 100644
--- a/thunder/transforms/tensor_wrapper_subclass.py
+++ b/thunder/transforms/tensor_wrapper_subclass.py
@@ -225,6 +225,14 @@ def torch_interpreted_func(*args, **kwargs):
     return torch_trace
+
+
+def create_ctor_of_tensor_subclass(unflatten_method, tensor_names):
+    def ctor(tensors, metadata):
+        inner_tensors = dict(zip(tensor_names, tensors))
+        return unflatten_method(inner_tensors, metadata, -1, -1)
+
+    return ctor
+
+
 @dataclass
 class DesugarTensorSubclass:
     computation_trace: TraceCtx
@@ -270,14 +278,7 @@ def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
     def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
         return p._non_tensor_attr_names

-    def translate_fx_graph_into_bsym(
-        self,
-        bsym: BoundSymbol,
-        fx_graph: GraphModule,
-    ) -> BoundSymbol | tuple[BoundSymbol, ...]:
-        import thunder.torch as ltorch
-        from thunder.torch import _torch_to_thunder_function_map
-
+    def _unwrap_bsym_args(self, bsym: BoundSymbol) -> tuple[dict[int, ProxyInterface], list[BoundSymbol]]:
         unwrapped_bsym_args: dict[int, ProxyInterface] = {}
         list_of_flattening_bsyms: list[BoundSymbol] = []
         for a in bsym.flat_args:
@@ -311,36 +312,19 @@ def translate_fx_graph_into_bsym(
                 with tracectx(self.computation_trace):
                     a = proxy(a)
                 unwrapped_bsym_args[len(unwrapped_bsym_args)] = a
+        return unwrapped_bsym_args, list_of_flattening_bsyms

-        node: Node
-        list_of_placeholder_node: list[Node] = []
-        list_of_function_call_node: list[Node] = []
-        node_of_output: Node
-        for node in fx_graph.graph.nodes:
-            if node.op == PLACEHOLDER:
-                list_of_placeholder_node.append(node)
-            if node.op == CALL_FUNCTION:
-                list_of_function_call_node.append(node)
-            if node.op == OUTPUT:
-                node_of_output = node
-        args = [n.target for n in list_of_placeholder_node]
-        arg_name_to_index = {a: i for i, a in enumerate(args)}
-        ltorch_ops_for_node_of_ops = []
-        for node in list_of_function_call_node:
-            op: OpOverload = node.target
-            if op not in _torch_to_thunder_function_map:
-                msg = (
-                    f"`thunder.torch` does not have corresponding op for {op}. "
-                    "Think about adding it to thunder/torch/default_torch_ops.py"
-                    f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}"
-                    f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}"
-                )
-                raise RuntimeError(msg)
-            ltorch_ops_for_node_of_ops.append(_torch_to_thunder_function_map[op])
-
+    def _evaluate_ltorch_op(
+        self,
+        list_of_function_call_node: list[Node],
+        ltorch_ops_for_node_of_ops: list[Symbol],
+        unwrapped_bsym_args: dict[int, ProxyInterface],
+        arg_name_to_index: dict[str, int],
+        *,
+        _cur_bsym_for_error_msg: BoundSymbol,
+        _cur_fxgraph_for_error_msg: GraphModule,
+    ) -> tuple[list[BoundSymbol], dict[str, TensorProxy]]:
         bsyms: list[BoundSymbol] = []
-        if list_of_flattening_bsyms:
-            bsyms.extend(list_of_flattening_bsyms)
         fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {}
         for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops):
             args: list[Node] = node.args
@@ -369,73 +353,137 @@ def translate_fx_graph_into_bsym(
                 msg = (
                     f"Failing to map `torch.{node}` to `thunder.torch` op of "
                     f"{ltorch_op} with args of {arg_proxies}\n"
-                    f"BoundSymbol in question is\n```python\n{bsym}\n```\n"
-                    f"Corresponding torch.fx Graph is\n```python\n{fx_graph.print_readable(print_output=False)}\n```\n"
+                    f"BoundSymbol in question is\n```python\n{_cur_bsym_for_error_msg}\n```\n"
+                    f"Corresponding torch.fx Graph is\n```python\n{_cur_fxgraph_for_error_msg.print_readable(print_output=False)}\n```\n"
                     f"Original error is {e}"
                 )
                 raise type(e)(msg)
             else:
                 fxnode_output_name_to_tensor_proxy[str(node)] = out
         bsyms.extend(self.computation_trace.pop_scope())
-        if len(bsyms) == 0:
+        return bsyms, fxnode_output_name_to_tensor_proxy
+
+    def _postprocess_subclass_from_torch_op(
+        self,
+        orig_output: SubclassTensorProxy,
+        node_of_output: Node,
+        arg_name_to_index: dict[str, int],
+        unwrapped_bsym_args: dict[int, ProxyInterface],
+        fxnode_output_name_to_tensor_proxy: dict[str, OpOverload],
+    ) -> tuple[BoundSymbol, SubclassTensorProxy]:
+        # note(crcrpar): args[0] would be list of tensors, and args[1] could be list of non-tensors.
+        args: list[Node] = node_of_output.args[0]
+        new_tensor_proxies = []
+        new_non_tensor_values = []
+        for a in args:
+            value = a
+            if isinstance(a, Node):
+                if isinstance(a.target, str):
+                    value = unwrapped_bsym_args[arg_name_to_index[a.target]]
+                else:
+                    value = fxnode_output_name_to_tensor_proxy[str(a)]
+            if isinstance(value, TensorProxy):
+                new_tensor_proxies.append(value)
+            elif isinstance(value, (immutable_dict, immutable_list)):
+                if isinstance(value, immutable_dict):
+                    new_non_tensor_values.append(dict(value))
+                else:
+                    new_non_tensor_values.append(list(value))
+            else:
+                new_non_tensor_values.append(value)
+        utils.check(
+            len(orig_output._tensors) == len(new_tensor_proxies),
+            lambda: (
+                f"The number of new tensor proxies for {orig_output=} does not match: "
+                f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}"
+            ),
+        )
+        with tracectx(self.computation_trace):
+            new_subclass = orig_output.replace()
+        new_subclass._tensors = new_tensor_proxies
+        for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
+            setattr(new_subclass, name, value)
+        return (
+            prims.unflatten_tensor_subclass.bind(
+                new_subclass._subclass_type,
+                dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
+                dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
+                output=new_subclass,
+            ),
+            new_subclass,
+        )
+
+    def translate_fx_graph_into_bsym(
+        self,
+        bsym: BoundSymbol,
+        fx_graph: GraphModule,
+        header: str,
+    ) -> BoundSymbol | tuple[BoundSymbol, ...]:
+        from thunder.torch import _torch_to_thunder_function_map
+
+        unwrapped_bsym_args, list_of_flattening_bsyms = self._unwrap_bsym_args(bsym)
+
+        node: Node
+        list_of_placeholder_node: list[Node] = []
+        list_of_function_call_node: list[Node] = []
+        node_of_output: Node
+        for node in fx_graph.graph.nodes:
+            if node.op == PLACEHOLDER:
+                list_of_placeholder_node.append(node)
+            if node.op == CALL_FUNCTION:
+                list_of_function_call_node.append(node)
+            if node.op == OUTPUT:
+                node_of_output = node
+        args = [n.target for n in list_of_placeholder_node]
+        arg_name_to_index = {a: i for i, a in enumerate(args)}
+        ltorch_ops_for_node_of_ops = []
+        for node in list_of_function_call_node:
+            op: OpOverload = node.target
+            if op not in _torch_to_thunder_function_map:
+                msg = (
+                    f"`thunder.torch` does not have corresponding op for {op}. "
+                    "Think about adding it to thunder/torch/default_torch_ops.py"
+                    f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}"
+                    f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}"
+                )
+                raise RuntimeError(msg)
+            ltorch_ops_for_node_of_ops.append(_torch_to_thunder_function_map[op])
+
+        bsyms: list[BoundSymbol] = []
+        if list_of_flattening_bsyms:
+            bsyms.extend(list_of_flattening_bsyms)
+        bsyms_of_torch_ops, fxnode_output_name_to_tensor_proxy = self._evaluate_ltorch_op(
+            list_of_function_call_node,
+            ltorch_ops_for_node_of_ops,
+            unwrapped_bsym_args,
+            arg_name_to_index,
+            _cur_bsym_for_error_msg=bsym,
+            _cur_fxgraph_for_error_msg=fx_graph,
+        )
+        bsyms.extend(bsyms_of_torch_ops)
+        if not bsyms:
             return [bsym]

         orig_output = bsym.flat_outs[0]
         if is_subclass_ctor_bsym := bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR:
             utils.check_type(orig_output, SubclassTensorProxy)

         if isinstance(orig_output, SubclassTensorProxy):
-            # note(crcrpar): args[0] would be list of tensors, and args[1] could be list of non-tensors.
-            args: list[Node] = node_of_output.args[0]
-            new_tensor_proxies = []
-            new_non_tensor_values = []
-            for a in args:
-                value = a
-                if isinstance(a, Node):
-                    if isinstance(a.target, str):
-                        value = unwrapped_bsym_args[arg_name_to_index[a.target]]
-                    else:
-                        value = fxnode_output_name_to_tensor_proxy[str(a)]
-                if isinstance(value, TensorProxy):
-                    new_tensor_proxies.append(value)
-                elif isinstance(value, (immutable_dict, immutable_list)):
-                    if isinstance(value, immutable_dict):
-                        new_non_tensor_values.append(dict(value))
-                    else:
-                        new_non_tensor_values.append(list(v))
-                else:
-                    new_non_tensor_values.append(value)
-            utils.check(
-                len(orig_output._tensors) == len(new_tensor_proxies),
-                lambda: (
-                    f"The number of new tensor proxies for {orig_output=} does not match: "
-                    f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}"
-                ),
+            unflatten_bsym, new_subclass_proxy = self._postprocess_subclass_from_torch_op(
+                orig_output,
+                node_of_output,
+                arg_name_to_index,
+                unwrapped_bsym_args,
+                fxnode_output_name_to_tensor_proxy,
             )
-            with tracectx(self.computation_trace):
-                new_subclass = orig_output.replace()
-            new_subclass._tensors = new_tensor_proxies
-            for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
-                setattr(new_subclass, name, value)
-            bsyms.append(
-                prims.unflatten_tensor_subclass.bind(
-                    new_subclass._subclass_type,
-                    dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
-                    dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
-                    output=new_subclass,
-                )
-            )
-
-            self.swap_map[variableify(orig_output)] = new_subclass
-            self.subclass_proxy_to_flatten.add(variableify(new_subclass))
+            bsyms.append(unflatten_bsym)
+            self.swap_map[variableify(orig_output)] = new_subclass_proxy
+            self.subclass_proxy_to_flatten.add(variableify(new_subclass_proxy))
         else:
             non_none_args = [n for n in node_of_output.args[0] if n is not None]
             utils.check(len(non_none_args) == 1, lambda: f"{node_of_output.args = }")
             new_out_node = non_none_args[0]
             self.swap_map[variableify(orig_output)] = fxnode_output_name_to_tensor_proxy[str(new_out_node)]
-
-        args = ", ".join([t.name if isinstance(t, ProxyInterface) else f"{t}" for t in bsym.flat_args])
-        header = f"{bsym.sym.id}({args})"
         for i, sbsym in enumerate(bsyms, 1):
             sbsym.header = f"[{i}/{len(bsyms)}] unrolled `__torch_dispatch__` of `{header}`"
         return bsyms
@@ -444,12 +492,6 @@ def convert_trace_to_fx_graph_and_get_fake_result(
         self,
         trace: TraceCtx,
     ) -> tuple[GraphModule, tuple[OutputWrapperForFxTracing, ...], tuple[torch.Tensor, ...], PyTreeSpec]:
-        def create_ctor(unflatten_method, tensor_names):
-            def ctor(tensors, metadata):
-                inner_tensors = dict(zip(tensor_names, tensors))
-                return unflatten_method(inner_tensors, metadata, -1, -1)
-
-            return ctor

         args = tree_map(
             lambda t: maybe_materialize_tensor(
@@ -467,7 +509,9 @@ def ctor(tensors, metadata):
                 desugared_args.extend([getattr(a, name) for name in attrs])
                 desugared_args.append(metadta)
                 end_idx = len(desugared_args)
-                arg_idx_to_sugar[start_idx] = end_idx, create_ctor(type(a).__tensor_unflatten__, attrs)
+                arg_idx_to_sugar[start_idx] = end_idx, create_ctor_of_tensor_subclass(
+                    type(a).__tensor_unflatten__, attrs
+                )
             else:
                 desugared_args.append(a)

@@ -621,7 +665,9 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:

            bsym_with_modified_output = updated_bsym.from_bsym_swap_proxies(self.swap_map)
            self.bsym_to_new_outputs[bsym_with_modified_output] = bsym_with_modified_output
-            return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)
+            args = ", ".join([t.name if isinstance(t, ProxyInterface) else f"{t}" for t in bsym.flat_args])
+            header = f"{bsym.sym.id}({args})"
+            return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx, header=header)


 def tensor_subclass_dce(trace: TraceCtx) -> TraceCtx: