
Memory usage overflow issue #154

Open
xiaoxiaojiea opened this issue Mar 5, 2024 · 2 comments

Comments

@xiaoxiaojiea

Dear developers, first of all thank you for this valuable work. In actual use, however, I have run into a problem: memory usage keeps growing during training until the run is eventually killed. I am currently reviewing the source code to track down the cause, but since I am not very familiar with the code, progress is slow. In the meantime, I would like to ask whether you already have a solution for this problem.

@xiaoxiaojiea
Author

This is now resolved. The fix is as follows:

Root cause: the call torch.cat(in_cache, dim=-1) in the forward function of class FSMN(nn.Module) in fsmn.py keeps copying the cache while holding references to its history, so memory grows without bound.

Fix: replace that line with the following code:

        # x7 = self.softmax(x6)
        x7, _ = x6
        # return x7, None

        # ===============================
        # Instead of torch.cat(in_cache, dim=-1), pre-allocate the cache
        # tensor and copy each entry in without its autograd history.
        cat_size = sum(tensor.size(-1) for tensor in in_cache)
        ret_cache = torch.zeros(
            in_cache[0].shape[0], in_cache[0].shape[1],
            in_cache[0].shape[2], cat_size,
            device=in_cache[0].device, dtype=in_cache[0].dtype,
        )

        # Copy slice by slice; this also handles cache tensors whose last
        # dimension is larger than 1.
        offset = 0
        for tensor in in_cache:
            width = tensor.size(-1)
            ret_cache[:, :, :, offset:offset + width] = tensor.detach()
            offset += width
        # ===============================

        return x7, ret_cache
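For context on why the detach() matters: a cache tensor that still carries autograd history keeps alive the entire computation graph of the step that produced it, and feeding it back into the next step chains the graphs together, so memory grows with every iteration. A minimal standalone sketch of the effect (generic tensors, not the project's actual FSMN code):

```python
import torch

# A parameter and a cache that is carried across steps.
x = torch.randn(2, 3, requires_grad=True)
cache = torch.zeros(2, 3)

for step in range(3):
    out = (x * 2 + cache).sum()
    # Detaching before storing drops the reference to this step's graph.
    # Without .detach(), the next iteration's graph would chain onto this
    # one, and the accumulated graphs would never be freed.
    cache = (x * 2 + cache).detach()

assert cache.grad_fn is None       # detached: no autograd history retained
assert not cache.requires_grad
```

Whether detaching the cache is safe depends on the training setup: it is fine when gradients are not meant to flow back through the cached history (as with truncated BPTT), which is exactly what the collaborator asks about below.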

@mlxu995
Collaborator

mlxu995 commented Mar 11, 2024

I'm not very familiar with fsmn. @duj12, could you check whether the detach() here would affect FSMN training?
