diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index 33fe178f20..3b6df599dc 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -247,8 +247,8 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( TENSOR_ON_CUDA_GPU(feature_requires_grad); } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 875ed32dbf..f2eb67c43a 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -427,8 +427,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e } {%- endif %} - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); {%- if nobag and not is_index_select %} auto max_D = D; diff --git a/fbgemm_gpu/codegen/embedding_bounds_check.cu b/fbgemm_gpu/codegen/embedding_bounds_check.cu index a93be2ba6c..0e89b57617 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/embedding_bounds_check.cu @@ -190,8 +190,7 @@ void bounds_check_indices_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( rows_per_table, indices, offsets, warning, weights, B_offsets); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(rows_per_table.get_device()); + CUDA_DEVICE_GUARD(rows_per_table); const int32_t T = rows_per_table.size(0); const int32_t total_B = offsets.size(0) - 1; diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu index 4dcda7bf7e..00f1cc1f5c 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu @@ -140,8 +140,8 @@ Tensor pruned_hashmap_lookup_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, hash_table, hash_table_offsets); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(indices.get_device()); + CUDA_DEVICE_GUARD(indices); + auto dense_indices = at::empty_like(indices); const int32_t T = hash_table_offsets.size(0) - 1; const int32_t B = (offsets.size(0) - 1) / T; @@ -179,8 +179,8 @@ Tensor pruned_array_lookup_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, index_remappings, index_remappings_offsets); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(indices.get_device()); + CUDA_DEVICE_GUARD(indices); + auto dense_indices = at::empty_like(indices); const int32_t T = index_remappings_offsets.size(0) - 1; TORCH_CHECK( diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_host_template.cu index 2d70981fe7..e556cc241c 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_host_template.cu @@ -107,8 +107,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ 
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); // kernels assume indices are contiguous. indices = indices.contiguous(); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index 1aabf9d0ec..91be4dc0d9 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -364,8 +364,7 @@ batch_index_select_dim0_codegen_forward_cuda( } {%- endif %} - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); {%- if not nobag %} int32_t T = D_offsets.numel() - 1; diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_template.cu b/fbgemm_gpu/codegen/embedding_optimizer_split_template.cu index ec1accd903..613b3a670d 100644 --- a/fbgemm_gpu/codegen/embedding_optimizer_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_optimizer_split_template.cu @@ -82,8 +82,7 @@ void split_embedding_{{ optimizer }}_update( return; } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); // Flatten dev_weights because it is currrently 2D dev_weights = dev_weights.flatten(); diff --git a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu index 018798aec6..fcd9073d97 100644 --- a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu +++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu @@ -133,8 +133,7 @@ void embedding_inplace_update_cuda( lxu_cache_weights, lxu_cache_locations); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); const int64_t N = update_row_idx.numel(); if (N == 0) { @@ -226,9 +225,8 @@ Tensor pruned_array_lookup_from_row_idx_cuda( update_table_indices, index_remappings, index_remappings_offsets); + CUDA_DEVICE_GUARD(update_table_indices); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(update_table_indices.get_device()); auto dense_indices = at::empty_like(update_row_indices); const int32_t T = index_remappings_offsets.size(0) - 1; diff --git a/fbgemm_gpu/src/histogram_binning_calibration_ops.cu b/fbgemm_gpu/src/histogram_binning_calibration_ops.cu index ca7a8a46e9..6c98261b48 100644 --- a/fbgemm_gpu/src/histogram_binning_calibration_ops.cu +++ b/fbgemm_gpu/src/histogram_binning_calibration_ops.cu @@ -64,9 +64,7 @@ std::tuple histogram_binning_calibration_cuda( TENSOR_ON_CUDA_GPU(bin_num_examples); TENSOR_ON_CUDA_GPU(bin_num_positives); TORCH_CHECK_EQ(bin_num_examples.numel(), bin_num_positives.numel()); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(logit.get_device()); + CUDA_DEVICE_GUARD(logit); Tensor calibrated_prediction = at::empty_like(logit); Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); @@ -188,9 +186,7 @@ std::tuple histogram_binning_calibration_by_feature_cuda( TENSOR_ON_CUDA_GPU(bin_num_examples); TENSOR_ON_CUDA_GPU(bin_num_positives); TORCH_CHECK_EQ(bin_num_examples.numel(), bin_num_positives.numel()); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(logit.get_device()); + CUDA_DEVICE_GUARD(logit); // 
Convert lengths to offsets for better handling on GPUs. const auto segment_lengths_packed = segment_lengths.contiguous(); @@ -351,9 +347,7 @@ generic_histogram_binning_calibration_by_feature_cuda( TORCH_CHECK( bin_num_examples.numel() == (num_segments + 1) * (bin_boundaries.numel() + 1)); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(logit.get_device()); + CUDA_DEVICE_GUARD(logit); // Convert lengths to offsets for better handling on GPUs. const auto segment_lengths_packed = segment_lengths.contiguous(); diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine.cu b/fbgemm_gpu/src/input_combine_ops/input_combine.cu index 3f799c9d27..d944418864 100644 --- a/fbgemm_gpu/src/input_combine_ops/input_combine.cu +++ b/fbgemm_gpu/src/input_combine_ops/input_combine.cu @@ -105,8 +105,7 @@ std::tuple tbe_input_combine_with_length_cuda( const uint64_t max_list_size, const c10::DeviceIndex& device) { constexpr uint32_t IS_LONG_NUM_BITS = 32; - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(device); + at::cuda::OptionalCUDAGuard device_guard(device); // combined_indices and combined_legnths are int tensors const auto int_options = at::TensorOptions().dtype(at::kInt).device( diff --git a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu index 9517a2567c..1160bfee39 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu @@ -91,9 +91,7 @@ std::tuple batched_dense_vec_jagged_2d_mul_backward( const Tensor& a_values, const Tensor& a_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad_output, a_values, a_offsets, v); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); const int B = a_offsets.numel() - 1; const int D = grad_output.size(-1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu index b0a9be2794..fc38bd3f05 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu @@ -56,9 +56,7 @@ Tensor batched_dense_vec_jagged_2d_mul_forward( const Tensor& a_values, const Tensor& a_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(v, a_values, a_offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(v.get_device()); + CUDA_DEVICE_GUARD(v); const int B = a_offsets.numel() - 1; TORCH_CHECK( diff --git a/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu index ffbf6c41a0..1caf70bb0e 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu @@ -29,8 +29,7 @@ Tensor dense_to_jagged_forward( auto values = at::empty_symint({total_L_computed, D}, dense.options()); auto output = at::empty_like(values); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dense.get_device()); + CUDA_DEVICE_GUARD(dense); #define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \ AT_DISPATCH_CASE(TYPE, [&] { \ diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_bmm_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_bmm_forward.cu index dc40aec8dd..fd2992fec6 100644 --- 
a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_bmm_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_bmm_forward.cu @@ -156,9 +156,7 @@ Tensor jagged_dense_bmm_forward_cuda( const Tensor& y, const int64_t max_L) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(x_values, x_offsets, y); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(x_values.get_device()); + CUDA_DEVICE_GUARD(x_values); const int B = x_offsets.numel() - 1; const int M = x_values.size(-1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu index dc19378547..3b94e62ab6 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu @@ -218,8 +218,7 @@ Tensor jagged_dense_dense_elementwise_add_jagged_output_forward( TORCH_CHECK_EQ(dense_0.sizes(), dense_1.sizes()); auto output = at::empty_like(x_values); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dense_0.get_device()); + CUDA_DEVICE_GUARD(dense_0); if (x_values.scalar_type() == at::ScalarType::BFloat16 && dense_0.scalar_type() == at::ScalarType::BFloat16 && diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu index 41fb56a899..4a45c87158 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu @@ -128,8 +128,7 @@ std::tuple jagged_dense_elementwise_mul_backward( const std::vector& x_offsets, const Tensor& y, const Tensor& x_values) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); Tensor x_values_grad = at::empty_like(grad_output); Tensor y_grad = at::empty_like(y); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu index 7149e47826..296df03e07 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu @@ -16,8 +16,7 @@ Tensor jagged_dense_elementwise_mul_forward( const Tensor& x_values, const std::vector& x_offsets, const Tensor& y) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(x_values.get_device()); + CUDA_DEVICE_GUARD(x_values); Tensor output = at::empty_like(x_values); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu index 30b604d072..5d92891a30 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu @@ -78,9 +78,7 @@ Tensor jagged_index_add_2d_forward_cuda( const int64_t num_output_rows) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( values, indices, input_offsets, output_offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); auto num_cols = values.size(1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu index 618a4cd30f..c0c064a0aa 100644 --- 
a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu @@ -74,9 +74,7 @@ Tensor jagged_index_select_2d_forward_cuda( const int64_t num_dense_output_rows) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( values, indices, input_offsets, output_offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); auto num_cols = values.size(1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu index 55c75a8e03..7acd5b7f3f 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu @@ -162,9 +162,7 @@ Tensor jagged_jagged_bmm_forward_cuda( const Tensor& offsets, const int64_t max_L) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(x_values, y_values, offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(x_values.get_device()); + CUDA_DEVICE_GUARD(x_values); const int B = offsets.numel() - 1; const int M = x_values.size(-1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu index 76c2cad3ca..07cbed5d1f 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu @@ -96,9 +96,7 @@ Tensor jagged_softmax_backward_cuda( const Tensor& offsets, const int64_t max_L) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad_output, output, offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); const auto B = offsets.numel() - 1; const auto D = grad_output.size(1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu index 725431c368..18a6da4a92 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu @@ -119,9 +119,7 @@ Tensor jagged_softmax_forward_cuda( const Tensor& offsets, const int64_t max_L) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(values, offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); const auto B = offsets.numel() - 1; const auto D = values.size(1); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu index e460652fab..cbccc29e24 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu @@ -17,8 +17,7 @@ at::Tensor jagged_to_padded_dense_backward( const std::vector& offsets, at::SymInt total_L) { auto grad_padded_values = grad_output; - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_padded_values.get_device()); + CUDA_DEVICE_GUARD(grad_padded_values); // Canonicalize padded_values by unsqueeze the last dim if the inner dense // dimension is 1 and folded. 
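Note: the CUDA_DEVICE_GUARD macro itself is not part of this diff. A minimal sketch of what it presumably expands to, assuming it is defined in fbgemm_gpu's CUDA utilities header (e.g. fbgemm_gpu/fbgemm_cuda_utils.cuh, which memory_utils.cu below starts to include), based on the two-line pattern it replaces at every tensor-based call site:

// Hypothetical sketch, not taken from this diff: wrap the removed boilerplate
// in a macro so each call site guards on the device that owns TENSOR.  The
// guard restores the previous current device when it goes out of scope.
#define CUDA_DEVICE_GUARD(TENSOR)           \
  at::cuda::OptionalCUDAGuard device_guard; \
  device_guard.set_index(TENSOR.get_device())

Under that assumption the behavior of each call site is unchanged; only the boilerplate shrinks to a single line.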
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu index 8689d23939..ec3d06240f 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu @@ -30,8 +30,7 @@ at::Tensor jagged_to_padded_dense_forward( max_lengths.size(), " != num_jagged_dim, ", num_jagged_dim); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); const Tensor values_canonicalized = values.view( {values.size(0), @@ -83,8 +82,7 @@ std::vector stacked_jagged_1d_to_dense_gpu( int64_t padding_value) { TORCH_CHECK(values.dim() == 1); TORCH_CHECK(lengths.dim() == 2); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); const auto lengths_contig = lengths.contiguous(); int32_t B = lengths.size(1); @@ -138,8 +136,7 @@ stacked_jagged_2d_to_dense_forward_cuda( int64_t padding_value) { TORCH_CHECK(values.dim() == 2); TORCH_CHECK(lengths.dim() == 2); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); const auto lengths_contig = lengths.contiguous(); int32_t D = values.size(1); @@ -194,8 +191,7 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda( const std::vector& grad_padded_values_per_key, const std::vector& offsets_tensor_per_key, const std::vector& offset_per_key) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_padded_values_per_key[0].get_device()); + CUDA_DEVICE_GUARD(grad_padded_values_per_key[0]); auto grad_values = at::zeros({total_L, D}, grad_padded_values_per_key[0].options()); @@ -321,8 +317,7 @@ class JaggedDenseAddJaggedOutputGPUOp auto output = at::empty_like(x_values); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dense.get_device()); + CUDA_DEVICE_GUARD(dense); AT_DISPATCH_SWITCH( x_values.scalar_type(), @@ -364,9 +359,7 @@ class JaggedDenseAddJaggedOutputGPUOp auto offsets = ctx->get_saved_variables(); auto dense_shape = ctx->saved_data["dense_shape"].toIntVector(); TORCH_CHECK(grad_outputs.size() == 1); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_outputs[0].get_device()); + CUDA_DEVICE_GUARD(grad_outputs[0]); Tensor dense_values_grad = jagged_to_padded_dense_forward( grad_outputs[0], diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index c588e4ef3a..989e41ce88 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -194,8 +194,7 @@ class KeyedJaggedIndexSelectDim1GPUOp "weights size and values size must be the same"); } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); const int num_batches = lengths.numel() / batch_size; const int num_output_lengths = num_batches * indices.numel(); @@ -380,8 +379,7 @@ class KeyedJaggedIndexSelectDim1GPUOp int64_t output_batch_size = ctx->saved_data["batch_size"].toInt(); int64_t num_batches = ctx->saved_data["num_batches"].toInt(); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad.get_device()); + CUDA_DEVICE_GUARD(grad); Tensor grad_input = at::zeros({num_outputs}, grad.options()); auto grid_size = 
cuda_calc_xblock_count(grad.numel(), kMaxThreads); diff --git a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu index 110775a999..5b696c221f 100644 --- a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu @@ -37,8 +37,7 @@ Tensor recat_embedding_grad_output_cuda( const std::vector& num_features_per_rank) { TENSOR_ON_CUDA_GPU(grad_output); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); TORCH_CHECK(grad_output.is_contiguous()); const auto B_local = grad_output.size(0); @@ -82,8 +81,7 @@ Tensor recat_embedding_grad_output_mixed_D_cuda( TENSOR_ON_CUDA_GPU(grad_output); TORCH_CHECK(grad_output.is_contiguous()); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); const auto B_local = grad_output.size(0); const auto global_dim_sum = at::sum_integers(dim_sum_per_rank); @@ -129,8 +127,7 @@ Tensor recat_embedding_grad_output_mixed_D_batch_cuda( grad_output, dim_sum_per_rank, cumsum_dim_sum_per_rank); TORCH_CHECK(grad_output.is_contiguous()); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); const auto B_local = grad_output.size(0); Tensor sharded_grad_output = diff --git a/fbgemm_gpu/src/memory_utils/memory_utils.cu b/fbgemm_gpu/src/memory_utils/memory_utils.cu index 83c64a30bb..46ad469495 100644 --- a/fbgemm_gpu/src/memory_utils/memory_utils.cu +++ b/fbgemm_gpu/src/memory_utils/memory_utils.cu @@ -7,6 +7,7 @@ */ #include "common.cuh" +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" using namespace at; @@ -32,8 +33,7 @@ struct CUDAHostMappedContext { : ptr_(ptr), cuda_device_(cuda_device){}; ~CUDAHostMappedContext() { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(cuda_device_); + at::cuda::OptionalCUDAGuard device_guard(cuda_device_); AT_CUDA_CHECK(cudaHostUnregister(ptr_)); free(ptr_); } @@ -51,8 +51,7 @@ struct CUDAManagedContext { : ptr_(ptr), cuda_device_(cuda_device){}; ~CUDAManagedContext() { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(cuda_device_); + at::cuda::OptionalCUDAGuard device_guard(cuda_device_); AT_CUDA_CHECK(cudaFree(ptr_)); } @@ -88,8 +87,7 @@ std::vector defaultStrides(IntArrayRef sizes) { Tensor new_managed_tensor_internal( const Tensor& self, const std::vector& sizes) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(self.get_device()); + CUDA_DEVICE_GUARD(self); auto strides = defaultStrides(sizes); size_t size_bytes = @@ -150,8 +148,7 @@ std::tuple adjust_to_page_boundaries(void* ptr, size_t size) { Tensor new_managed_tensor( const Tensor& self, const std::vector& sizes) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(self.get_device()); + CUDA_DEVICE_GUARD(self); Tensor t = new_managed_tensor_internal(self, sizes); @@ -187,8 +184,7 @@ Tensor new_managed_tensor_meta( Tensor new_vanilla_managed_tensor( const Tensor& self, const std::vector& sizes) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(self.get_device()); + CUDA_DEVICE_GUARD(self); return new_managed_tensor_internal(self, sizes); } @@ -196,8 +192,7 @@ Tensor new_vanilla_managed_tensor( Tensor new_host_mapped_tensor( const Tensor& self, const std::vector& sizes) { - at::cuda::OptionalCUDAGuard device_guard; - 
device_guard.set_index(self.get_device()); + CUDA_DEVICE_GUARD(self); auto strides = defaultStrides(sizes); size_t size_bytes = diff --git a/fbgemm_gpu/src/metric_ops/metric_ops.cu b/fbgemm_gpu/src/metric_ops/metric_ops.cu index 6bb867be9a..8547b87d81 100644 --- a/fbgemm_gpu/src/metric_ops/metric_ops.cu +++ b/fbgemm_gpu/src/metric_ops/metric_ops.cu @@ -244,8 +244,7 @@ at::Tensor batch_auc( block_sums = at::empty({grid_size * 2}, output_options); } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); auto max_smem_size = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu index 7331e59c20..de69ab63da 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu @@ -70,8 +70,8 @@ Tensor permute_pooled_embs_gpu_impl( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(pooled_embs.get_device()); + CUDA_DEVICE_GUARD(pooled_embs); + // We couldn't pass the "pooled_embs.is_contiguous()" check in the backward // passs after D22767058. TODO: optimize and make sure pooled_embs is // contiguous. diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu index 6221ea63a7..855b48724e 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu @@ -69,8 +69,8 @@ Tensor permute_pooled_embs_split_gpu_impl( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(pooled_embs.get_device()); + CUDA_DEVICE_GUARD(pooled_embs); + // We couldn't pass the "pooled_embs.is_contiguous()" check in the backward // passs after D22767058. TODO: optimize and make sure pooled_embs is // contiguous. diff --git a/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu b/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu index 4e42a9d8e4..68552abadd 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu @@ -22,8 +22,7 @@ namespace fbgemm_gpu { /// @return A new tensor with values from the input tensor converted to /// `bfloat16`. DLL_PUBLIC at::Tensor _float_to_bfloat16_gpu(const at::Tensor& input) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia // NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16 @@ -53,8 +52,7 @@ DLL_PUBLIC at::Tensor _float_to_bfloat16_gpu(const at::Tensor& input) { /// /// @return A new tensor with values from the input tensor converted to `float`. 
DLL_PUBLIC at::Tensor _bfloat16_to_float_gpu(const at::Tensor& input) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); auto output = at::empty({}, input.options().dtype(at::kFloat)); output.resize_(0); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu index 5623ddba6f..bfd1076936 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu @@ -196,8 +196,7 @@ template Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) { TENSOR_ON_CUDA_GPU(input); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; @@ -353,9 +352,7 @@ Tensor _FP8rowwise_to_float_gpu_t( const int64_t output_dtype) { TENSOR_ON_CUDA_GPU(input); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu index 05e2f9655c..36e1685e59 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu @@ -219,9 +219,7 @@ template Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) { TENSOR_ON_CUDA_GPU(input); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; @@ -375,9 +373,7 @@ template Tensor _fused8bitrowwise_to_float_gpu_t(const Tensor& input) { TENSOR_ON_CUDA_GPU(input); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; @@ -520,9 +516,7 @@ DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_mixed_dim_gpu( // row of each table TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(input); TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(D_offsets); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const int64_t batch_size = input.size(0); const int qparam_size = 8; diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu index 90c84caba4..8ac122c71b 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu @@ -113,9 +113,7 @@ Tensor _float_to_fusednbitrowwise_gpu_t( const int64_t bit_rate) { TENSOR_ON_CUDA_GPU(input); TENSOR_NDIM_EQUALS(input, 2); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const int nrows = input.size(0); const int ncols = input.size(1); @@ -220,9 +218,7 @@ Tensor _fusednbitrowwise_to_float_gpu_t( const int64_t bit_rate) { TENSOR_ON_CUDA_GPU(input); TENSOR_NDIM_EQUALS(input, 
2); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const int nrows = input.size(0); const int ncols = input.size(1); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu b/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu index 4b7d56ed95..730bacd1cc 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu @@ -31,9 +31,7 @@ DLL_PUBLIC at::Tensor _float_to_hfp8_gpu( const double max_pos) { TORCH_CHECK(ebits > 0); TORCH_CHECK(exponent_bias > 0); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); auto output = at::empty({}, input.options().dtype(at::kByte)); output.resize_(0); @@ -68,9 +66,7 @@ DLL_PUBLIC at::Tensor _hfp8_to_float_gpu( const int64_t exponent_bias) { TORCH_CHECK(ebits > 0); TORCH_CHECK(exponent_bias > 0); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); auto output = at::empty({}, input.options().dtype(at::kFloat)); output.resize_(0); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu b/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu index 1416a8110d..dbada12073 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu @@ -137,8 +137,7 @@ DLL_PUBLIC at::Tensor _float_to_msfp_gpu( TORCH_CHECK(ebits > 0 && mbits > 0); TORCH_CHECK(min_pos > 0 && max_pos > 0 && max_pos > min_pos); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const int nrows = input.size(0); const int ncols = input.size(1); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu index 59c95b3c13..6c1d4ef66e 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu @@ -207,9 +207,7 @@ Tensor _float_to_paddedFP8rowwise_gpu_t( const bool forward, const int64_t row_dim) { TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(input); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; @@ -265,9 +263,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t( const int64_t output_dtype) { TENSOR_ON_CUDA_GPU(input); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cu b/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cu index c72877337d..5edcc428de 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cu @@ -14,8 +14,7 @@ namespace fbgemm_gpu { DLL_PUBLIC Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) { TENSOR_ON_CUDA_GPU(t_in); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(t_in.get_device()); + CUDA_DEVICE_GUARD(t_in); if (t_in.numel() == 0) { return at::empty_like(t_in); @@ -59,8 +58,7 @@ DLL_PUBLIC Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) { DLL_PUBLIC Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) { TENSOR_ON_CUDA_GPU(t_in); - at::cuda::OptionalCUDAGuard 
device_guard; - device_guard.set_index(t_in.get_device()); + CUDA_DEVICE_GUARD(t_in); if (t_in.numel() == 0) { return at::empty_like(t_in); @@ -104,8 +102,8 @@ DLL_PUBLIC Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) { DLL_PUBLIC Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) { TENSOR_ON_CUDA_GPU(t_in); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(t_in.get_device()); + CUDA_DEVICE_GUARD(t_in); + size_t temp_storage_bytes = 0; TORCH_CHECK(t_in.is_contiguous()); TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_batched_unary_embeddings.cu b/fbgemm_gpu/src/sparse_ops/sparse_batched_unary_embeddings.cu index 4fe267224f..d6f8f548dc 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_batched_unary_embeddings.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_batched_unary_embeddings.cu @@ -54,9 +54,8 @@ Tensor batched_unary_embeddings_forward_cuda( TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(weight); TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(offsets); TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(indices); + CUDA_DEVICE_GUARD(weight); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weight.get_device()); // N: number of tasks, T: number of tables, B: batch size const int32_t N = weight.size(0); const int32_t T = table_offsets.numel() - 1; @@ -177,8 +176,7 @@ DLL_PUBLIC Tensor batched_unary_embeddings_backward_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( grad_output, weight, table_offsets, offsets, indices); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); // N: number of tasks, T: number of tables, B: batch size const int32_t N = grad_output.size(0); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu index 676ec9f4d8..596e52357a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu @@ -190,8 +190,8 @@ block_bucketize_sparse_features_cuda( const c10::optional>& block_bucketize_pos) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(lengths.get_device()); + CUDA_DEVICE_GUARD(lengths); + // allocate tensors and buffers const auto lengths_size = lengths.numel(); const auto T = block_sizes.numel(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_bucketize_features.cu index 4f9cec0ec1..cc5e802c1f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_bucketize_features.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_bucketize_features.cu @@ -93,8 +93,8 @@ bucketize_sparse_features_cuda( const c10::optional& weights) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(lengths.get_device()); + CUDA_DEVICE_GUARD(lengths); + // allocate tensors and buffers const int lengths_size = lengths.numel(); const int new_lengths_size = lengths_size * my_size; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu b/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu index 55e6e7a1e7..3e2e9ab765 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu @@ -46,8 +46,7 @@ DLL_PUBLIC Tensor expand_into_jagged_permute_cuda( TORCH_CHECK(permute.numel() == 
input_offsets.numel() - 1); TORCH_CHECK(permute.numel() == output_offsets.numel() - 1); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(permute.get_device()); + CUDA_DEVICE_GUARD(permute); const auto permute_contig = permute.contiguous(); const auto permute_size = permute.numel(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index 59472eafda..764b466957 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -110,8 +110,7 @@ DLL_PUBLIC void group_index_select_or_add_cuda( return; } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(device); + at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_index_add.cu b/fbgemm_gpu/src/sparse_ops/sparse_index_add.cu index 8041262287..2c43274ffe 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_index_add.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_index_add.cu @@ -83,8 +83,7 @@ DLL_PUBLIC Tensor index_add_with_unique_indices_cuda( std::vector& input_shape, const int consecutive_range_start, const int consecutive_range_length) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_output.get_device()); + CUDA_DEVICE_GUARD(grad_output); const int N = grad_output.size(0); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu b/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu index f74b2d4f1d..9e6605b259 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu @@ -53,8 +53,7 @@ DLL_PUBLIC Tensor index_select_cuda( const Tensor& indices, const Tensor& orig_indices, const bool indices_sorted) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); + CUDA_DEVICE_GUARD(input); const int N = indices.size(0); auto output_shape = input.sizes().vec(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_invert_permute.cu b/fbgemm_gpu/src/sparse_ops/sparse_invert_permute.cu index 76695029b0..4532fd3fc3 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_invert_permute.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_invert_permute.cu @@ -24,8 +24,8 @@ __global__ __launch_bounds__(kMaxThreads) void invert_permute_kernel( DLL_PUBLIC Tensor invert_permute_cuda(const Tensor& permute) { TENSOR_ON_CUDA_GPU(permute); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(permute.get_device()); + CUDA_DEVICE_GUARD(permute); + const auto permute_contig = permute.contiguous(); const auto permute_size = permute.numel(); Tensor inversed_permute = at::empty_like(permute); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu index 4d4cda7219..a102b6bc54 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu @@ -60,8 +60,7 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( max_length == data.size(1), "max_length should be equal to the second dimension of the packed segments"); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(data.get_device()); + CUDA_DEVICE_GUARD(data); Tensor unpacked_tensor; // The output tensor diff --git a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_forward.cu b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_forward.cu index ec989cca9b..a0eb9e6966 100644 --- 
a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_forward.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_forward.cu @@ -67,8 +67,7 @@ DLL_PUBLIC Tensor pack_segments_forward_cuda( "t_in must be of type float or double or half or bfloat16"); TORCH_CHECK_GT(max_length, 0); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(t_in.get_device()); + CUDA_DEVICE_GUARD(t_in); const auto t_in_c = t_in.contiguous(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu index 0f50c8e581..505964c5c8 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu @@ -36,8 +36,7 @@ DLL_PUBLIC Tensor permute102_baddbmm_permute102_cuda( TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(B); TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(bias); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(A.get_device()); + CUDA_DEVICE_GUARD(A); TENSORS_ON_SAME_DEVICE(A, B); TENSORS_ON_SAME_DEVICE(A, bias); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu index 145d279c0d..5b90a85e54 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu @@ -70,8 +70,7 @@ permute_1D_sparse_data_cuda( const c10::optional& permuted_lengths_sum) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(permute, lengths, indices, weights); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(indices.get_device()); + CUDA_DEVICE_GUARD(indices); const auto permute_contig = permute.contiguous(); const auto lengths_contig = lengths.contiguous(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu index 24086e0031..4642d291a7 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu @@ -77,8 +77,7 @@ permute_2D_sparse_data_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(permute, lengths, indices, weights); TORCH_CHECK(lengths.dim() == 2); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(indices.get_device()); + CUDA_DEVICE_GUARD(indices); const auto permute_contig = permute.contiguous(); const auto lengths_contig = lengths.contiguous(); @@ -241,8 +240,7 @@ permute_sparse_features_cuda( const c10::optional& weights) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(permute, lengths, indices, weights); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(indices.get_device()); + CUDA_DEVICE_GUARD(indices); // the following implementation requires lengths and indices has the same // dtype if usecase comes up that requires different dtype (e.g. 
int32 for diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_embeddings.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_embeddings.cu index 908d036eb7..3664d02cee 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_embeddings.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_embeddings.cu @@ -52,8 +52,7 @@ DLL_PUBLIC std::tuple permute_sequence_embeddings_cuda( // wrapper for permute_2D_sparse_data_cuda, kept for BC TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(permute, lengths, embeddings); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(embeddings.get_device()); + CUDA_DEVICE_GUARD(embeddings); TORCH_CHECK( lengths.dim() == 2, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_range.cu b/fbgemm_gpu/src/sparse_ops/sparse_range.cu index b639cc1972..94a90699fc 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_range.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_range.cu @@ -75,8 +75,7 @@ offsets_range_cuda(const Tensor& offsets, int64_t range_size) { TENSOR_ON_CUDA_GPU(offsets); TENSOR_NDIM_EQUALS(offsets, 1); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(offsets.get_device()); + CUDA_DEVICE_GUARD(offsets); auto offsets_arg = at::TensorArg(offsets, "offsets", 1); checkScalarTypes("_offsets_range_cuda", offsets_arg, {at::kLong, at::kInt}); @@ -113,8 +112,7 @@ DLL_PUBLIC Tensor lengths_range_cuda( TENSOR_ON_CUDA_GPU(t_in); TENSOR_NDIM_EQUALS(t_in, 1); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(t_in.get_device()); + CUDA_DEVICE_GUARD(t_in); const auto t_in_contig = t_in.contiguous(); const auto num_seq = t_in_contig.numel(); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu index dc55104229..898640b8f8 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu @@ -55,8 +55,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_lengths_gpu( const bool broadcast_lengths) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(cat_ad_lengths, batch_offsets); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(cat_ad_lengths.get_device()); + CUDA_DEVICE_GUARD(cat_ad_lengths); const int64_t B = batch_offsets.numel() - 1; const int64_t T = broadcast_lengths @@ -190,8 +189,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( cat_ad_offsets, cat_ad_indices, reordered_cat_ad_offsets, batch_offsets); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(cat_ad_offsets.get_device()); + CUDA_DEVICE_GUARD(cat_ad_offsets); const int64_t B = batch_offsets.numel() - 1; const int64_t T = (reordered_cat_ad_offsets.numel() - 1) / num_ads_in_batch; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu b/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu index 4db734af80..e2b0a30656 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu @@ -55,8 +55,7 @@ DLL_PUBLIC Tensor segment_sum_csr_cuda( const Tensor& values) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(csr_seg, values); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); + CUDA_DEVICE_GUARD(values); auto output = at::empty(csr_seg.numel() - 1, values.options()); constexpr uint32_t threads_per_block = 256; diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu index 0382de8e96..7fe594d1a9 100644 --- 
a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu @@ -82,8 +82,7 @@ void lfu_update_counts_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( unique_indices, unique_indices_length, unique_indices_count, lfu_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(unique_indices.get_device()); + CUDA_DEVICE_GUARD(unique_indices); const int32_t N = unique_indices.size(0); AT_DISPATCH_INDEX_TYPES( @@ -115,8 +114,7 @@ std::pair lfu_cache_find_uncached_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( unique_indices, unique_indices_length, lxu_cache_state, lfu_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(unique_indices.get_device()); + CUDA_DEVICE_GUARD(unique_indices); auto cache_sets = full_like( unique_indices, diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu index 76cbe6d440..63df6a09f6 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu @@ -187,8 +187,7 @@ void lfu_cache_insert_cuda( lxu_cache_weights, lfu_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); const int32_t N = cache_set_sorted_unique_indices.numel(); @@ -262,8 +261,7 @@ DLL_PUBLIC void lfu_cache_populate_cuda( lxu_cache_weights, lfu_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu index b3906d844c..be659ecd2f 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu @@ -165,8 +165,7 @@ void lfu_cache_insert_byte_cuda( lxu_cache_weights, lfu_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); const int32_t N = cache_set_sorted_unique_indices.numel(); @@ -231,8 +230,7 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda( lxu_cache_weights, lfu_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); diff --git a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu index 9bb6fe5c86..6b508e74fd 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu @@ -63,8 +63,7 @@ DLL_PUBLIC Tensor linearize_cache_indices_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( cache_hash_size_cumsum, indices, offsets); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(cache_hash_size_cumsum.get_device()); + CUDA_DEVICE_GUARD(cache_hash_size_cumsum); const auto T = cache_hash_size_cumsum.size(0) - 1; TORCH_CHECK(T > 0); @@ -146,8 +145,7 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( cache_hash_size_cumsum, update_table_indices, update_row_indices); - at::cuda::OptionalCUDAGuard device_guard; - 
device_guard.set_index(cache_hash_size_cumsum.get_device()); + CUDA_DEVICE_GUARD(cache_hash_size_cumsum); const auto T = cache_hash_size_cumsum.size(0) - 1; TORCH_CHECK(T > 0); @@ -188,8 +186,7 @@ get_unique_indices_cuda( bool compute_count) { TENSOR_ON_CUDA_GPU(linear_indices); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(linear_indices.get_device()); + CUDA_DEVICE_GUARD(linear_indices); TORCH_CHECK(linear_indices.numel() < std::numeric_limits::max()); const int32_t N = linear_indices.numel(); diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu index a3c4926624..3ea4cd3498 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu @@ -47,8 +47,7 @@ DLL_PUBLIC Tensor emulate_cache_miss( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( lxu_cache_locations, uvm_cache_stats); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(lxu_cache_locations.get_device()); + CUDA_DEVICE_GUARD(lxu_cache_locations); const auto N = lxu_cache_locations.numel(); if (N == 0) { @@ -170,8 +169,7 @@ DLL_PUBLIC std::pair lru_cache_find_uncached_cuda( uvm_cache_stats, lxu_cache_locking_counter); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(unique_indices.get_device()); + CUDA_DEVICE_GUARD(unique_indices); // Fill with sentinel value auto cache_sets = full_like( diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu index 13896890ae..631dd7ef94 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu @@ -209,8 +209,7 @@ void lru_cache_insert_cuda( uvm_cache_stats, lxu_cache_locking_counter); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); const int32_t N = cache_set_sorted_unique_indices.numel(); @@ -317,8 +316,7 @@ DLL_PUBLIC void lru_cache_populate_cuda( TENSOR_ON_CUDA_GPU(lxu_cache_locking_counter_); } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu index 40be037da1..fdb052d009 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu @@ -104,8 +104,7 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda( lru_state, lxu_cache_miss_timestamp); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(linear_cache_indices.get_device()); + CUDA_DEVICE_GUARD(linear_cache_indices); const int32_t N = linear_cache_indices.numel(); @@ -391,8 +390,7 @@ void lru_cache_insert_byte_cuda( lru_state, uvm_cache_stats); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); const int32_t N = cache_set_sorted_unique_indices.numel(); @@ -463,8 +461,7 @@ void direct_mapped_lru_cache_insert_byte_cuda( linear_cache_indices, lxu_cache_miss_timestamp); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); const int32_t N = cache_sets.size(0); @@ -542,8 +539,7 @@ DLL_PUBLIC void 
lru_cache_populate_byte_cuda( TENSOR_ON_CUDA_GPU(uvm_cache_stats_); } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); @@ -635,8 +631,7 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda( auto uvm_cache_stats_ = uvm_cache_stats.value_or( at::empty({0}, weights.options().dtype(at::kInt))); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(weights.get_device()); + CUDA_DEVICE_GUARD(weights); TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); diff --git a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu index 3e39eb9e23..088e911e6b 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu @@ -104,8 +104,7 @@ DLL_PUBLIC void lxu_cache_flush_cuda( lxu_cache_state, lxu_cache_weights); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(lxu_cache_weights.get_device()); + CUDA_DEVICE_GUARD(lxu_cache_weights); const int32_t T = D_offsets.numel() - 1; const int32_t S = lxu_cache_weights.size(0); @@ -194,8 +193,7 @@ void lxu_cache_locking_counter_decrement_cuda( TENSOR_ON_CUDA_GPU(lxu_cache_locking_counter); TENSOR_ON_CUDA_GPU(lxu_cache_locations); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(lxu_cache_locations.get_device()); + CUDA_DEVICE_GUARD(lxu_cache_locations); const auto N = lxu_cache_locations.numel(); if (N == 0) { @@ -427,8 +425,7 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda( uvm_cache_stats_ = uvm_cache_stats.value(); } - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(linear_cache_indices.get_device()); + CUDA_DEVICE_GUARD(linear_cache_indices); const auto lxu_cache_locations = lxu_cache_locations_output.value_or(empty_like( @@ -484,8 +481,7 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda( auto uvm_cache_stats_ = uvm_cache_stats.value_or( at::empty({0}, linear_cache_indices.options().dtype(at::kInt))); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(linear_cache_indices.get_device()); + CUDA_DEVICE_GUARD(linear_cache_indices); const auto N = linear_cache_indices.numel(); auto lxu_cache_locations = empty_like( @@ -549,8 +545,7 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( lxu_cache_locations, lxu_cache_locations_new, num_uniq_cache_indices); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(lxu_cache_locations.get_device()); + CUDA_DEVICE_GUARD(lxu_cache_locations); const auto N = lxu_cache_locations.numel(); diff --git a/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu b/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu index 321c225772..a4e6c6bd67 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu @@ -234,8 +234,7 @@ DLL_PUBLIC void reset_weight_momentum_cuda( buffer_ids, cache_hash_size_cumsum, lxu_cache_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(dev_weights.get_device()); + CUDA_DEVICE_GUARD(dev_weights); const int64_t num_pruned_indices = pruned_indices.size(0); const int32_t num_pruned_tables = buffer_ids.size(0); diff --git a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu index 
ea00786333..83b2fa3c7d 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu @@ -130,8 +130,7 @@ generate_vbe_metadata( TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T); TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(B_offsets.get_device()); + CUDA_DEVICE_GUARD(B_offsets); Tensor row_output_offsets = at::empty({total_B}, output_offsets_feature_rank.options()); diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index d1eb5e00a1..6677be4f0d 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -22,8 +22,8 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(t_in.get_device()); + CUDA_DEVICE_GUARD(t_in); + size_t temp_storage_bytes = 0; TORCH_CHECK(t_in.is_contiguous()); TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index 770d2eb2ae..614281c612 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -81,8 +81,8 @@ Tensor masked_index_put_cuda( Tensor count) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(self, indices, values, count); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(self.get_device()); + CUDA_DEVICE_GUARD(self); + const auto N = indices.numel(); if (N == 0) { return self; @@ -224,8 +224,7 @@ std::tuple ssd_cache_populate_actions_cuda( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( linear_indices, lxu_cache_state, lru_state); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(linear_indices.get_device()); + CUDA_DEVICE_GUARD(linear_indices); // Get unique indices Tensor unique_indices;
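Note: call sites that only have a raw device index rather than a tensor (tbe_input_combine_with_length_cuda, group_index_select_or_add_cuda, and the context destructors in memory_utils.cu) keep an explicit guard but switch to the one-argument constructor. A minimal usage sketch of that form, with a hypothetical function name:

#include <ATen/cuda/CUDAGuard.h>

// Hypothetical example, not from this diff: when no tensor is at hand, the
// guard is constructed directly from the device index instead of calling
// set_index() on a default-constructed guard.
void run_on_device(const c10::DeviceIndex device) {
  at::cuda::OptionalCUDAGuard device_guard(device);
  // Kernels launched here target `device`; the previous current device is
  // restored when `device_guard` is destroyed at the end of the scope.
}

The two spellings behave the same at runtime; the constructor form is simply the natural choice when the index is already available.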