Skip to content

Commit 3192d53

Browse files
committed
bugfix: output tensor is assumed contiguous, so we should use d not stride
Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
1 parent 595de84 commit 3192d53

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

include/flashinfer/sampling.cuh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,8 +1862,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
18621862
(logits_vec[j] > pivot) ? logits_vec[j] : -cuda::std::numeric_limits<float>::infinity();
18631863
}
18641864
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
1865-
logits_vec.store(masked_logits + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
1866-
tx * VEC_SIZE);
1865+
logits_vec.store(masked_logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
18671866
}
18681867
}
18691868
}
@@ -1987,8 +1986,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
19871986
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
19881987
}
19891988
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
1990-
probs_vec.store(renormed_prob + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
1991-
tx * VEC_SIZE);
1989+
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
19921990
}
19931991
}
19941992
}

0 commit comments

Comments
 (0)