File tree Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change 11import torch
22from functools import partial
3+ from typing import Callable
34
4-
5- def custom_amp_decorator (dec , cuda_amp_deprecated ):
6- def decorator (func ):
7- return dec (func ) if not cuda_amp_deprecated else partial (dec , func , device_type = "cuda" )
5+ def custom_amp_decorator (dec : Callable , cuda_amp_deprecated : bool ):
6+ def decorator (* args , ** kwargs ):
7+ if cuda_amp_deprecated :
8+ kwargs ["device_type" ] = "cuda"
9+ return dec (* args , ** kwargs )
810 return decorator
911
1012
11- if hasattr (torch .amp , "custom_fwd" ):
13+ if hasattr (torch .amp , "custom_fwd" ): # type: ignore[attr-defined]
1214 deprecated = True
13- from torch .amp import custom_fwd , custom_bwd
15+ from torch .amp import custom_fwd , custom_bwd # type: ignore[attr-defined]
1416else :
1517 deprecated = False
1618 from torch .cuda .amp import custom_fwd , custom_bwd
You can’t perform that action at this time.
0 commit comments