Skip to content

Commit 807813a

Browse files
committed
Support non-contiguous input tensors for other sampling kernels
Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
1 parent 0fe8f70 commit 807813a

File tree

3 files changed

+81
-64
lines changed

3 files changed

+81
-64
lines changed

csrc/renorm.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ using tvm::ffi::Optional;
2323

2424
void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
2525
Optional<TensorView> maybe_top_p_arr, double top_p_val) {
26-
CHECK_INPUT(probs);
26+
CHECK_CUDA(probs);
27+
CHECK_LAST_DIM_CONTIGUOUS(probs);
2728
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
2829
unsigned int batch_size = probs->shape[0];
2930
unsigned int vocab_size = probs->shape[1];
@@ -34,7 +35,7 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
3435
cudaError_t status = sampling::TopPRenormProb<float>(
3536
static_cast<float*>(probs->data), static_cast<float*>(renorm_probs->data),
3637
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value()->data) : nullptr, batch_size,
37-
top_p_val, vocab_size, stream);
38+
top_p_val, vocab_size, probs->strides[0], stream);
3839
TVM_FFI_ICHECK(status == cudaSuccess)
3940
<< "TopPRenormProb failed with error code " << cudaGetErrorString(status);
4041
}

csrc/sampling.cu

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
4545

4646
void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
4747
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
48-
CHECK_INPUT(logits);
48+
CHECK_CUDA(logits);
49+
CHECK_LAST_DIM_CONTIGUOUS(logits);
4950
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
5051
unsigned int batch_size = output->shape[0];
5152
unsigned int vocab_size = logits->shape[1];
@@ -55,14 +56,16 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
5556
cudaError_t status = sampling::SamplingFromLogits(
5657
static_cast<float*>(logits->data), static_cast<int*>(output->data),
5758
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
58-
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
59+
batch_size, vocab_size, logits->strides[0], deterministic, philox_seed, philox_offset,
60+
stream);
5961
TVM_FFI_ICHECK(status == cudaSuccess)
6062
<< "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
6163
}
6264

6365
void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
6466
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
65-
CHECK_INPUT(probs);
67+
CHECK_CUDA(probs);
68+
CHECK_LAST_DIM_CONTIGUOUS(probs);
6669
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
6770
unsigned int batch_size = output->shape[0];
6871
unsigned int vocab_size = probs->shape[1];
@@ -72,7 +75,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
7275
cudaError_t status = sampling::SamplingFromProb(
7376
static_cast<float*>(probs->data), static_cast<int*>(output->data),
7477
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
75-
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
78+
batch_size, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
7679
TVM_FFI_ICHECK(status == cudaSuccess)
7780
<< "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
7881
}
@@ -81,7 +84,8 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
8184
Optional<TensorView> maybe_indices,
8285
Optional<TensorView> maybe_top_p_arr, double top_p_val,
8386
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
84-
CHECK_INPUT(probs);
87+
CHECK_CUDA(probs);
88+
CHECK_LAST_DIM_CONTIGUOUS(probs);
8589
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
8690
unsigned int batch_size = output->shape[0];
8791
unsigned int vocab_size = probs->shape[1];
@@ -93,7 +97,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
9397
static_cast<float*>(probs->data), static_cast<int*>(output->data),
9498
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
9599
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value()->data) : nullptr, batch_size,
96-
top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
100+
top_p_val, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
97101
TVM_FFI_ICHECK(status == cudaSuccess)
98102
<< "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
99103
}
@@ -102,7 +106,8 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
102106
Optional<TensorView> maybe_indices,
103107
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
104108
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
105-
CHECK_INPUT(probs);
109+
CHECK_CUDA(probs);
110+
CHECK_LAST_DIM_CONTIGUOUS(probs);
106111
CHECK_INPUT(output);
107112
CHECK_DEVICE(output, probs);
108113
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
@@ -117,7 +122,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
117122
static_cast<float*>(probs->data), static_cast<int*>(output->data),
118123
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
119124
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
120-
top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
125+
top_k_val, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
121126
TVM_FFI_ICHECK(status == cudaSuccess)
122127
<< "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
123128
}
@@ -126,7 +131,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
126131
Optional<TensorView> maybe_indices,
127132
Optional<TensorView> maybe_min_p_arr, double min_p_val,
128133
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
129-
CHECK_INPUT(probs);
134+
CHECK_CUDA(probs);
135+
CHECK_LAST_DIM_CONTIGUOUS(probs);
130136
CHECK_INPUT(output);
131137
CHECK_DEVICE(output, probs);
132138
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
@@ -142,7 +148,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
142148
has_min_p_arr ? static_cast<float*>(maybe_min_p_arr.value()->data) : nullptr,
143149
static_cast<int*>(output->data),
144150
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
145-
batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
151+
batch_size, min_p_val, vocab_size, probs->strides[0], deterministic, philox_seed,
152+
philox_offset, stream);
146153
TVM_FFI_ICHECK(status == cudaSuccess)
147154
<< "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
148155
}
@@ -153,7 +160,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
153160
Optional<TensorView> maybe_top_p_arr, double top_p_val,
154161
bool deterministic, uint64_t philox_seed,
155162
uint64_t philox_offset) {
156-
CHECK_INPUT(probs);
163+
CHECK_CUDA(probs);
164+
CHECK_LAST_DIM_CONTIGUOUS(probs);
157165
CHECK_INPUT(output);
158166
CHECK_DEVICE(output, probs);
159167
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
@@ -171,8 +179,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
171179
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value()->data) : nullptr,
172180
static_cast<int*>(output->data),
173181
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
174-
batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
175-
stream);
182+
batch_size, top_k_val, top_p_val, vocab_size, probs->strides[0], deterministic, philox_seed,
183+
philox_offset, stream);
176184
TVM_FFI_ICHECK(status == cudaSuccess)
177185
<< "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
178186
}

0 commit comments

Comments
 (0)