Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tripleMu committed May 28, 2022
1 parent 699dd7a commit 6ab4beb
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions mmcv/runner/fp16_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
pass


def cast_tensor_type(inputs: Union[torch.Tensor, nn.Module, str, np.ndarray,
abc.Mapping, abc.Iterable],
src_type: torch.dtype,
dst_type: torch.dtype) -> torch.Tensor:
def cast_tensor_type(
inputs: Union[torch.Tensor, nn.Module, str, np.ndarray, abc.Mapping,
abc.Iterable], src_type: torch.dtype, dst_type: torch.dtype
) -> Union[torch.Tensor, nn.Module, str, abc.Mapping, abc.Iterable]:
"""Recursively convert Tensor in inputs from src_type to dst_type.
Note:
Expand Down Expand Up @@ -56,13 +56,13 @@ def cast_tensor_type(inputs: Union[torch.Tensor, nn.Module, str, np.ndarray,
elif isinstance(inputs, np.ndarray):
return inputs
elif isinstance(inputs, abc.Mapping):
return type(inputs)({ # type: ignore
return type(inputs)({ # type:ignore
k: cast_tensor_type(v, src_type, dst_type)
for k, v in inputs.items()
})
elif isinstance(inputs, abc.Iterable):
return type(inputs)( # type: ignore
cast_tensor_type(item, src_type, dst_type) for item in inputs)
return type(inputs)(cast_tensor_type(item, src_type, dst_type)
for item in inputs) # type:ignore
else:
return inputs

Expand Down Expand Up @@ -360,7 +360,7 @@ class LossScaler:
"""

def __init__(self,
init_scale: float = 2**32,
init_scale: float = 2.**32,
mode: str = 'dynamic',
scale_factor: float = 2.,
scale_window: int = 1000):
Expand Down

0 comments on commit 6ab4beb

Please sign in to comment.