Skip to content

Commit

Permalink
fix out of bound error (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy authored Sep 25, 2024
1 parent fd08d0f commit 535dcb1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
26 changes: 26 additions & 0 deletions csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

class OptionalCUDAGuard {
int set_device_ = -1;
int current_device_ = -1;

public:
OptionalCUDAGuard(int device) : set_device_(device) {
cudaError_t err = cudaGetDevice(&current_device_);
std::stringstream ss;
if (err != cudaSuccess) {
ss << "cudaGetDevice failed with error code " << cudaGetErrorString(err);
TORCH_CHECK(err == cudaSuccess, ss.str());
}
if (current_device_ == device) {
return;
}
err = cudaSetDevice(device);
if (err != cudaSuccess) {
ss << "cudaGetDevice failed with error code " << cudaGetErrorString(err);
TORCH_CHECK(err == cudaSuccess, ss.str());
}
}
~OptionalCUDAGuard() {
if (set_device_ != current_device_) cudaSetDevice(current_device_);
}
};

#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);

inline void gpuAssert(cudaError_t code, const char* file, int line) {
Expand Down
2 changes: 2 additions & 0 deletions csrc/dequant_impl_packed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
const c10::optional<torch::Tensor>& outliers_indices, //[num_cen, c_size, ol_in_f]
const c10::optional<torch::Tensor>& outliers_centroids, //[num_c, c_size, out_vec_len]
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias) {
OptionalCUDAGuard cudaguard(q_indice.device().index());
int base_groupsize = centroids.size(-1); // how many elements in a vector
int res_groupsize = residual_centroids.has_value() ? residual_centroids.value().size(-1) : 0;
// TORCH_CHECK((res_groupsize===base_groupsize||res_groupsize==0), "res_groupsize===base_groupsize is false, must be
Expand Down Expand Up @@ -443,6 +444,7 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
const c10::optional<torch::Tensor>& outliers_centroids, //[num_c, c_size, out_vec_len]
const c10::optional<torch::Tensor>& perm, const torch::Tensor& weight_scale, const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias) {
OptionalCUDAGuard cudaguard(input.device().index());
const int base_groupsize = centroids.size(-1);
int index_bits = log2(centroids.size(1));
int res_index_bits = residual_centroids.has_value() ? log2(residual_centroids.value().size(1)) : 0;
Expand Down
5 changes: 3 additions & 2 deletions csrc/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,12 @@ __device__ __forceinline__ uint32_t iterator_packed_tensor(const uint32_t* ptr,
int second = end_bits / 32;
start_bits = start_bits % 32;
end_bits = end_bits % 32;
uint32_t sec_v = ptr[second];
uint32_t v = (ptr[first] >> (start_bits)) & ((1 << WBITS) - 1);
if (first == second) {
if (first == second || end_bits == 0) {
return v;
} else {
// second position might be out of bound
uint32_t sec_v = ptr[second];
v |= ((sec_v) & ((1 << (end_bits)) - 1)) << (32 - start_bits);
return v;
}
Expand Down

0 comments on commit 535dcb1

Please sign in to comment.