@@ -45,7 +45,8 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
4545
4646void sampling_from_logits (TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
4747 bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
48- CHECK_INPUT (logits);
48+ CHECK_CUDA (logits);
49+ CHECK_LAST_DIM_CONTIGUOUS (logits);
4950 CHECK_DIM (2 , logits); // logits: (batch_size, vocab_size)
5051 unsigned int batch_size = output->shape [0 ];
5152 unsigned int vocab_size = logits->shape [1 ];
@@ -55,14 +56,16 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
5556 cudaError_t status = sampling::SamplingFromLogits (
5657 static_cast <float *>(logits->data ), static_cast <int *>(output->data ),
5758 maybe_indices.has_value () ? static_cast <int *>(maybe_indices.value ()->data ) : nullptr ,
58- batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
59+ batch_size, vocab_size, logits->strides [0 ], deterministic, philox_seed, philox_offset,
60+ stream);
5961 TVM_FFI_ICHECK (status == cudaSuccess)
6062 << " SamplingFromLogits failed with error code " << cudaGetErrorString (status);
6163}
6264
6365void sampling_from_probs (TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
6466 bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
65- CHECK_INPUT (probs);
67+ CHECK_CUDA (probs);
68+ CHECK_LAST_DIM_CONTIGUOUS (probs);
6669 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
6770 unsigned int batch_size = output->shape [0 ];
6871 unsigned int vocab_size = probs->shape [1 ];
@@ -72,7 +75,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
7275 cudaError_t status = sampling::SamplingFromProb (
7376 static_cast <float *>(probs->data ), static_cast <int *>(output->data ),
7477 maybe_indices.has_value () ? static_cast <int *>(maybe_indices.value ()->data ) : nullptr ,
75- batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
78+ batch_size, vocab_size, probs-> strides [ 0 ], deterministic, philox_seed, philox_offset, stream);
7679 TVM_FFI_ICHECK (status == cudaSuccess)
7780 << " SamplingFromProbs failed with error code " << cudaGetErrorString (status);
7881}
@@ -81,7 +84,8 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
8184 Optional<TensorView> maybe_indices,
8285 Optional<TensorView> maybe_top_p_arr, double top_p_val,
8386 bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
84- CHECK_INPUT (probs);
87+ CHECK_CUDA (probs);
88+ CHECK_LAST_DIM_CONTIGUOUS (probs);
8589 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
8690 unsigned int batch_size = output->shape [0 ];
8791 unsigned int vocab_size = probs->shape [1 ];
@@ -93,7 +97,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
9397 static_cast <float *>(probs->data ), static_cast <int *>(output->data ),
9498 maybe_indices.has_value () ? static_cast <int *>(maybe_indices.value ()->data ) : nullptr ,
9599 has_top_p_arr ? static_cast <float *>(maybe_top_p_arr.value ()->data ) : nullptr , batch_size,
96- top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
100+ top_p_val, vocab_size, probs-> strides [ 0 ], deterministic, philox_seed, philox_offset, stream);
97101 TVM_FFI_ICHECK (status == cudaSuccess)
98102 << " TopPSamplingFromProbs failed with error code " << cudaGetErrorString (status);
99103}
@@ -102,7 +106,8 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
102106 Optional<TensorView> maybe_indices,
103107 Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
104108 bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
105- CHECK_INPUT (probs);
109+ CHECK_CUDA (probs);
110+ CHECK_LAST_DIM_CONTIGUOUS (probs);
106111 CHECK_INPUT (output);
107112 CHECK_DEVICE (output, probs);
108113 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
@@ -117,7 +122,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
117122 static_cast <float *>(probs->data ), static_cast <int *>(output->data ),
118123 maybe_indices.has_value () ? static_cast <int *>(maybe_indices.value ()->data ) : nullptr ,
119124 has_top_k_arr ? static_cast <float *>(maybe_top_k_arr.value ()->data ) : nullptr , batch_size,
120- top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
125+ top_k_val, vocab_size, probs-> strides [ 0 ], deterministic, philox_seed, philox_offset, stream);
121126 TVM_FFI_ICHECK (status == cudaSuccess)
122127 << " TopKSamplingFromProbs failed with error code " << cudaGetErrorString (status);
123128}
@@ -126,7 +131,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
126131 Optional<TensorView> maybe_indices,
127132 Optional<TensorView> maybe_min_p_arr, double min_p_val,
128133 bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
129- CHECK_INPUT (probs);
134+ CHECK_CUDA (probs);
135+ CHECK_LAST_DIM_CONTIGUOUS (probs);
130136 CHECK_INPUT (output);
131137 CHECK_DEVICE (output, probs);
132138 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
@@ -142,7 +148,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
142148 has_min_p_arr ? static_cast <float *>(maybe_min_p_arr.value ()->data ) : nullptr ,
143149 static_cast <int *>(output->data ),
144150 maybe_indices.has_value () ? static_cast <int *>(maybe_indices.value ()->data ) : nullptr ,
145- batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
151+ batch_size, min_p_val, vocab_size, probs->strides [0 ], deterministic, philox_seed,
152+ philox_offset, stream);
146153 TVM_FFI_ICHECK (status == cudaSuccess)
147154 << " MinPSamplingFromProb failed with error code " << cudaGetErrorString (status);
148155}
@@ -153,7 +160,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
153160 Optional<TensorView> maybe_top_p_arr, double top_p_val,
154161 bool deterministic, uint64_t philox_seed,
155162 uint64_t philox_offset) {
156- CHECK_INPUT (probs);
163+ CHECK_CUDA (probs);
164+ CHECK_LAST_DIM_CONTIGUOUS (probs);
157165 CHECK_INPUT (output);
158166 CHECK_DEVICE (output, probs);
159167 CHECK_DIM (2 , probs); // probs: (batch_size, vocab_size)
@@ -171,8 +179,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
171179 has_top_p_arr ? static_cast <float *>(maybe_top_p_arr.value ()->data ) : nullptr ,
172180 static_cast <int *>(output->data ),
173181 maybe_indices.has_value () ? static_cast <int *>(maybe_indices.value ()->data ) : nullptr ,
174- batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset ,
175- stream);
182+ batch_size, top_k_val, top_p_val, vocab_size, probs-> strides [ 0 ], deterministic, philox_seed ,
183+ philox_offset, stream);
176184 TVM_FFI_ICHECK (status == cudaSuccess)
177185 << " TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString (status);
178186}
0 commit comments