Skip to content

[ET-VK][EZ] Shorten torch.fx.Node to Node #2403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.fx import Node

_ScalarType = Union[int, bool, float]
_Argument = Union[torch.fx.Node, int, bool, float, str]
_Argument = Union[Node, int, bool, float, str]


class VkGraphBuilder:
Expand All @@ -29,7 +29,7 @@ def __init__(self, program: ExportedProgram) -> None:
self.output_ids = []
self.const_tensors = []

# Mapping from torch.fx.Node to VkValue id
# Mapping from Node to VkValue id
self.node_to_value_ids = {}

@staticmethod
Expand All @@ -39,18 +39,18 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
else:
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

def is_constant(self, node: torch.fx.Node):
def is_constant(self, node: Node):
return (
node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
)

def is_get_attr_node(self, node: torch.fx.Node) -> bool:
def is_get_attr_node(self, node: Node) -> bool:
"""
Returns true if the given node is a get attr node for a tensor of the model
"""
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
return isinstance(node, Node) and node.op == "get_attr"

def is_param_node(self, node: torch.fx.Node) -> bool:
def is_param_node(self, node: Node) -> bool:
"""
Check if the given node is a parameter within the exported program
"""
Expand All @@ -61,7 +61,7 @@ def is_param_node(self, node: torch.fx.Node) -> bool:
or self.is_constant(node)
)

def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
def get_constant(self, node: Node) -> Optional[torch.Tensor]:
"""
Returns the constant associated with the given node in the exported program.
Returns None if the node is not a constant within the exported program
Expand All @@ -79,7 +79,7 @@ def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:

return None

def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor:
def get_param_tensor(self, node: Node) -> torch.Tensor:
tensor = None
if node is None:
raise RuntimeError("node is None")
Expand Down Expand Up @@ -168,7 +168,7 @@ def create_string_value(self, string: str) -> int:
return new_id

def get_or_create_value_for(self, arg: _Argument):
if isinstance(arg, torch.fx.Node):
if isinstance(arg, Node):
# If the value has already been created, return the existing id
if arg in self.node_to_value_ids:
return self.node_to_value_ids[arg]
Expand Down