Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

在列表推导式中转换 torch API 时无法正确插入辅助代码 #521

Closed
guozixu2001 opened this issue Nov 27, 2024 · 0 comments · Fixed by #522
Closed

在列表推导式中转换 torch API 时无法正确插入辅助代码 #521

guozixu2001 opened this issue Nov 27, 2024 · 0 comments · Fixed by #522

Comments

@guozixu2001
Copy link
Contributor

问题描述

当在列表推导式中转换 torch API 时,由于 BaseTransformer::insert_multi_node 的限制,无法正确插入必要的辅助代码(如 import 语句和 sys.path 设置)。
这个问题会影响所有在列表推导式或字典推导式中使用辅助函数的 torch API 的代码转换。

复现步骤

  1. 转换如下的 PyTorch 代码:
import torch
a = torch.Tensor([[1.,2.], [3.,4.]])
list_a = [a,a]
result = [torch.transpose(input=x, dim1=0, dim0=1) for x in list_a ]
  1. 使用 PaConvert 得到的 Paddle 代码:
import paddle
a = paddle.to_tensor(data=[[1.0, 2.0], [3.0, 4.0]], dtype='float32')
list_a = [a, a]
result = [paddle.transpose(x=x, perm=paddle_aux.transpose_aux_func(x.ndim, 
    1, 0)) for x in list_a]

其中没有正常导入 paddle_aux

Bug 原因

当转换器处理列表推导式中的 torch.transpose() 时,需要插入辅助代码。但是由于 insert_multi_node 方法中的如下检查,会拦截在列表推导式中插入新节点:

if isinstance(self.parent_node, (ast.DictComp, ast.ListComp)):
    return False

导致必要的辅助代码无法被正确插入。

可能的修复方案

修改 insert_multi_node 方法,将 import 相关的节点(包括 import 语句和 sys.path 设置)始终插入到文件顶部,而不受推导式作用域的限制。例如:

def insert_multi_node(self, node_list):
    if len(node_list) == 0:
        return True

    import_nodes = []
    other_nodes = []
    for node in node_list:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            import_nodes.append(node)
        elif "sys.path" in astor.to_source(node):
            import_nodes.append(node)
        else:
            other_nodes.append(node)

    # 始终将import相关的节点插入到文件顶部
    if len(import_nodes) > 0:
        self.record_scope((self.root, "body", 0), import_nodes)

    # 其他节点仍然遵循原来的逻辑
    if len(other_nodes) > 0:
        if isinstance(self.parent_node, (ast.DictComp, ast.ListComp)):
            return False
        self.record_scope(self.scope_body_index(), other_nodes)

    return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant