We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
当在列表推导式中转换 torch API 时,由于 BaseTransformer::insert_multi_node 的限制,无法正确插入必要的辅助代码(如 import 语句和 sys.path 设置)。 这个问题会影响所有在列表推导式或字典推导式中使用辅助函数的 torch API 的代码转换。
BaseTransformer::insert_multi_node
import torch a = torch.Tensor([[1.,2.], [3.,4.]]) list_a = [a,a] result = [torch.transpose(input=x, dim1=0, dim0=1) for x in list_a ]
import paddle a = paddle.to_tensor(data=[[1.0, 2.0], [3.0, 4.0]], dtype='float32') list_a = [a, a] result = [paddle.transpose(x=x, perm=paddle_aux.transpose_aux_func(x.ndim, 1, 0)) for x in list_a]
其中没有正常导入 paddle_aux
当转换器处理列表推导式中的 torch.transpose() 时,需要插入辅助代码。但是由于 insert_multi_node 方法中的如下检查,会拦截在列表推导式中插入新节点:
torch.transpose()
insert_multi_node
if isinstance(self.parent_node, (ast.DictComp, ast.ListComp)): return False
导致必要的辅助代码无法被正确插入。
修改 insert_multi_node 方法,将 import 相关的节点(包括 import 语句和 sys.path 设置)始终插入到文件顶部,而不受推导式作用域的限制。例如:
def insert_multi_node(self, node_list): if len(node_list) == 0: return True import_nodes = [] other_nodes = [] for node in node_list: if isinstance(node, (ast.Import, ast.ImportFrom)): import_nodes.append(node) elif "sys.path" in astor.to_source(node): import_nodes.append(node) else: other_nodes.append(node) # 始终将import相关的节点插入到文件顶部 if len(import_nodes) > 0: self.record_scope((self.root, "body", 0), import_nodes) # 其他节点仍然遵循原来的逻辑 if len(other_nodes) > 0: if isinstance(self.parent_node, (ast.DictComp, ast.ListComp)): return False self.record_scope(self.scope_body_index(), other_nodes) return True
The text was updated successfully, but these errors were encountered:
Successfully merging a pull request may close this issue.
问题描述
当在列表推导式中转换 torch API 时,由于
BaseTransformer::insert_multi_node
的限制,无法正确插入必要的辅助代码(如 import 语句和 sys.path 设置)。这个问题会影响所有在列表推导式或字典推导式中使用辅助函数的 torch API 的代码转换。
复现步骤
其中没有正常导入 paddle_aux
Bug 原因
当转换器处理列表推导式中的
torch.transpose()
时,需要插入辅助代码。但是由于insert_multi_node
方法中的如下检查,会拦截在列表推导式中插入新节点:导致必要的辅助代码无法被正确插入。
可能的修复方案
修改
insert_multi_node
方法,将 import 相关的节点(包括 import 语句和 sys.path 设置)始终插入到文件顶部,而不受推导式作用域的限制。例如:The text was updated successfully, but these errors were encountered: