[Warning] stft with return_complex=False is deprecated. #2639

Closed
Alienpups opened this issue May 28, 2023 · 6 comments · Fixed by eginhard/coqui-tts#20
Labels
bug Something isn't working wontfix This will not be worked on but feel free to help.

Comments

@Alienpups

Describe the bug

Whenever I start training my model, I get this warning message. Is it just me or is this a thing?

TRAINING (2023-05-29 00:46:49)
C:\TTS-Training\lib\site-packages\torch\functional.py:641: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at ..\aten\src\ATen\native\SpectralOps.cpp:867.)
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]

To Reproduce

I run this command: C:\TTS-Training>python train_vits.py --continue_path "vits_louise-voice-April-21-2023_05+38PM-0000000".
Training starts, but I always get this stft deprecation warning.

Expected behavior

No response

Logs

No response

Environment

{
    "CUDA": {
        "GPU": [],
        "available": false,
        "version": null
    },
    "Packages": {
        "PyTorch_debug": false,
        "PyTorch_version": "2.0.0+cpu",
        "TTS": "0.14.0",
        "numpy": "1.23.5"
    },
    "System": {
        "OS": "Windows",
        "architecture": [
            "64bit",
            "WindowsPE"
        ],
        "processor": "Intel64 Family 6 Model 151 Stepping 2, GenuineIntel",
        "python": "3.10.11",
        "version": "10.0.22631"
    }
}

Additional context

No response

@Alienpups Alienpups added the bug Something isn't working label May 28, 2023
@Edresson
Contributor

Edresson commented Jun 15, 2023

Hi @Alienpups,

Yeah, it is a thing.

To fix it, we need to change our STFT class to use return_complex=True and replace:

M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

by

o = torch.view_as_real(o)
S = torch.norm(o, p=2, dim=-1)

A PR would be welcome :).
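A minimal, self-contained sketch of the suggested change. It is a standalone call, not TTS's actual STFT class, and the `n_fft`, `hop_length`, `win_length` values and random input are illustrative assumptions. It computes the magnitude both ways from a `return_complex=True` result. Note that the old code clamps the squared magnitude at 1e-8, so it never returns values below sqrt(1e-8) = 1e-4, while `torch.norm` alone has no such floor; clamping its result at 1e-4 reproduces the old values.

```python
import torch

# Illustrative parameters (assumed); TTS reads these from its own audio config.
n_fft, hop_length, win_length = 1024, 256, 1024
window = torch.hann_window(win_length)
y = torch.randn(2, 22050)  # batch of two random "waveforms"

# return_complex=True avoids the deprecation warning; o is a complex tensor
# of shape (batch, n_fft // 2 + 1, frames).
o = torch.stft(
    y, n_fft, hop_length=hop_length, win_length=win_length,
    window=window, center=True, return_complex=True,
)

# view_as_real adds a trailing dimension of size 2: [..., 0] is the real part,
# [..., 1] the imaginary part, i.e. the old return_complex=False layout.
o = torch.view_as_real(o)

# Suggested replacement: L2 norm over the (real, imaginary) pair.
S_new = torch.norm(o, p=2, dim=-1)

# Old computation, for comparison.
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S_old = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

# Clamping the new magnitude at sqrt(1e-8) = 1e-4 restores the old lower bound.
S_clamped = torch.clamp(S_new, min=1e-4)
```

Calling `.abs()` on the complex tensor (before `view_as_real`) gives the same magnitude as the norm over the last dimension.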

@Alienpups
Author

Alienpups commented Jun 23, 2023

To fix this warning, do I just need to make the changes you mentioned in torch_transforms.py? Did I understand that correctly? And what exactly does PR mean? :)

So I replace these three lines...
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

with the following? ...
o = torch.view_as_real(o)
S = torch.norm(o, p=2, dim=-1)

@Edresson
Contributor

Yes, you also need to use "return_complex=True" on the torch.stft function call.

Sorry, PR means pull request. If you like, you can make these changes, send a pull request, and become a 🐸 TTS contributor :).

@stale

stale bot commented Jul 24, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look at our discussion channels.

@stale stale bot added the wontfix This will not be worked on but feel free to help. label Jul 24, 2023
@pivolan

pivolan commented Jul 30, 2023

I have the same problem and fixed it by editing python3.10/site-packages/torch/functional.py. Near line 641, just before the return statement, find this line:

return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]

and insert this before it:


    if not return_complex:
        return torch.view_as_real(_VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex=True))
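Files under site-packages are overwritten whenever PyTorch is reinstalled or upgraded. A less invasive sketch of the same idea, assuming you can change the call site instead: a hypothetical helper `stft_real` (not part of PyTorch or TTS) that requests a complex result and converts it back with `torch.view_as_real`, so code expecting the old real `(..., 2)` layout keeps working without the warning.

```python
import warnings

import torch


def stft_real(x, n_fft, hop_length=None, win_length=None, window=None,
              center=True, normalized=False, onesided=None):
    """Hypothetical helper: returns the layout of the deprecated
    return_complex=False, a real tensor whose last dim holds (real, imag)."""
    spec = torch.stft(
        x, n_fft, hop_length=hop_length, win_length=win_length, window=window,
        center=center, normalized=normalized, onesided=onesided,
        return_complex=True,  # the supported path, no deprecation warning
    )
    return torch.view_as_real(spec)


# Usage: a 1-second, 16 kHz random signal; record any warnings raised.
x = torch.randn(16000)
win = torch.hann_window(512)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = stft_real(x, 512, hop_length=128, win_length=512, window=win)
print(out.shape)  # torch.Size([257, 126, 2])
```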


@stale stale bot removed the wontfix This will not be worked on but feel free to help. label Jul 30, 2023
@stale

stale bot commented Sep 7, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look at our discussion channels.

@stale stale bot added the wontfix This will not be worked on but feel free to help. label Sep 7, 2023
@stale stale bot closed this as completed Sep 14, 2023
jmaty added a commit to jmaty/Coqui-TTS that referenced this issue Jan 6, 2024
…coqui-ai#2639

Resolve "UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated" by using `x.mT` instead of `x.T`
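For context on the commit message above: on a tensor with more than two dimensions, `x.mT` swaps only the last two dimensions, while the deprecated `x.T` reverses all of them. A small sketch with an arbitrary 3-D shape (assumed for illustration); `permute` is used to show what `x.T` would do without triggering the warning.

```python
import torch

x = torch.randn(4, 3, 5)  # e.g. (batch, rows, cols)

# mT transposes only the last two dims: (4, 3, 5) -> (4, 5, 3).
y = x.mT
print(y.shape)  # torch.Size([4, 5, 3])

# Equivalent explicit form.
assert torch.equal(y, x.transpose(-2, -1))

# x.T on a 3-D tensor would reverse every dim, (4, 3, 5) -> (5, 3, 4);
# permute reproduces that result without the deprecated call.
z = x.permute(2, 1, 0)
print(z.shape)  # torch.Size([5, 3, 4])
```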
3 participants