diff --git a/mmdet/core/fp16/decorators.py b/mmdet/core/fp16/decorators.py index 10ffbf8..e1e9830 100644 --- a/mmdet/core/fp16/decorators.py +++ b/mmdet/core/fp16/decorators.py @@ -56,12 +56,9 @@ def new_func(*args, **kwargs): # NOTE: default args are not taken into consideration if args: arg_names = args_info.args[:len(args)] - for i, arg_name in enumerate(arg_names): - if arg_name in args_to_cast: - new_args.append( - cast_tensor_type(args[i], torch.float, torch.half)) - else: - new_args.append(args[i]) + new_args = [cast_tensor_type(args[i], torch.float, torch.half) + if arg_name in args_to_cast else args[i] + for (i, arg_name) in enumerate(arg_names)] # convert the kwargs that need to be processed new_kwargs = {} if kwargs: