File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
py/torch_tensorrt/dynamo/lowering Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -333,6 +333,7 @@ def scatter_reduce_decomposition(
333333 reduce : str ,
334334) -> torch .Tensor :
335335 scatter_loop_tensor = input_tensor
336+ device_input_tensor = input_tensor .device
336337 # required for mean reduce operation
337338 scatter_count_tensor = torch .zeros_like (input_tensor )
338339 src_shape = list (src_tensor .shape )
@@ -344,12 +345,11 @@ def scatter_reduce_decomposition(
344345 # unsqueeze src and index in dim
345346 src_slice = torch .unsqueeze (src_slice , dim )
346347 index_slice = torch .unsqueeze (index_slice , dim )
347- device = to_torch_device (default_device ())
348348
349349 # moving tensor to default device
350- scatter_loop_tensor = scatter_loop_tensor .to (device )
351- index_slice = index_slice .to (device )
352- src_slice = src_slice .to (device )
350+ scatter_loop_tensor = scatter_loop_tensor .to (device_input_tensor )
351+ index_slice = index_slice .to (device_input_tensor )
352+ src_slice = src_slice .to (device_input_tensor )
353353 if reduce == "sum" :
354354 reduceOp = ReduceOperation .SUM
355355 elif reduce == "prod" :
You can’t perform that action at this time.
0 commit comments