diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py
index a0115d3ad..c85497ab5 100644
--- a/graph_net/torch/backend/unstable_to_stable_backend.py
+++ b/graph_net/torch/backend/unstable_to_stable_backend.py
@@ -1,8 +1,6 @@
 import os
 import torch
 import sys
-import inspect
-import ast
 from .graph_compiler_backend import GraphCompilerBackend
 from ..fx_graph_serialize_util import serialize_graph_module_to_str
 
@@ -318,7 +316,26 @@ def replace_in_graph(graph_mod):
 
         return gm
 
-    # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
+    def _impl_unstable_to_stable_sdpa(self, gm):
+        """
+        Convert torch._C._nn.scaled_dot_product_attention to torch.nn.functional.scaled_dot_product_attention
+        """
+        issue_nodes = (
+            node
+            for node in gm.graph.nodes
+            if node.op == "call_function"
+            if hasattr(node.target, "__module__")
+            if node.target.__module__ == "torch._C._nn"
+            if hasattr(node.target, "__name__")
+            if node.target.__name__ == "scaled_dot_product_attention"
+        )
+
+        for node in issue_nodes:
+            node.target = torch.nn.functional.scaled_dot_product_attention
+
+        gm.recompile()
+
+        return gm
 
     def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
         """
diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py
index e8f71c8ac..10a79d9a7 100644
--- a/graph_net/torch/fx_graph_serialize_util.py
+++ b/graph_net/torch/fx_graph_serialize_util.py
@@ -148,7 +148,10 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
         (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
         (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu("),
-        # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
+        (
+            r"torch\._C\._nn\.scaled_dot_product_attention\(",
+            "torch.nn.functional.scaled_dot_product_attention(",
+        ),
         (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
     ]
     for pattern, repl in replacements:
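
A note on how the two halves of this change interact: retargeting the fx nodes fixes a live GraphModule before serialization, while the regex pair covers code that has already been rendered to source text. The snippet below is a minimal standalone sketch of the string-level rewrite only; the regex is copied from the patch, and the `src` line is a hypothetical example input, not actual serializer output.

import re

# Regex pair copied from the replacements table added to
# serialize_graph_module_to_str in fx_graph_serialize_util.py.
pattern = r"torch\._C\._nn\.scaled_dot_product_attention\("
repl = "torch.nn.functional.scaled_dot_product_attention("

# Hypothetical line of serialized fx code referencing the private binding.
src = "attn = torch._C._nn.scaled_dot_product_attention(query, key, value)"

print(re.sub(pattern, repl, src))
# attn = torch.nn.functional.scaled_dot_product_attention(query, key, value)

The trailing `\(` in the pattern keeps the match anchored to an actual call site, so an identifier that merely contains the function name as a prefix is left untouched.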