diff --git a/mmcv/device/mlu/scatter_gather.py b/mmcv/device/mlu/scatter_gather.py index a3cbbd3a79..0b0c9b96f5 100644 --- a/mmcv/device/mlu/scatter_gather.py +++ b/mmcv/device/mlu/scatter_gather.py @@ -16,7 +16,7 @@ def scatter_map(obj): if isinstance(obj, torch.Tensor): if target_mlus != [-1]: obj = obj.to('mlu') - return obj + return [obj] else: # for CPU inference we use self-implemented scatter return Scatter.forward(target_mlus, obj)