@@ -72,6 +72,9 @@ struct SliceBlockDesc {
   uint64_t size;
 };
 
+template <typename OutputType>
+constexpr int coalesced_values = sizeof(OutputType) >= 4 ? 1 : 4 / sizeof(OutputType);
+
 /**
  * @brief Simplified algorithm when no padding is necessary
  * @remarks `in` already refers to the slice anchor start
@@ -87,18 +90,22 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__
     return;
   }
 
-  for (; offset < block_end; offset += blockDim.x) {
-    uint64_t idx = offset;
-    uint64_t out_idx = idx;
-    uint64_t in_idx = 0;
-
+  for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
     #pragma unroll
-    for (int d = 0; d < Dims; d++) {
-      int i_d = div_mod(idx, idx, out_strides[d]);
-      in_idx += i_d * in_strides[d];
+    for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
+      uint64_t idx = offset + i;
+      if (idx >= block_end) break;
+      uint64_t out_idx = idx;
+      uint64_t in_idx = 0;
+
+      #pragma unroll
+      for (int d = 0; d < Dims; d++) {
+        int i_d = div_mod(idx, idx, out_strides[d]);
+        in_idx += i_d * in_strides[d];
+      }
+      in_idx += idx;  // remaining dims have equal strides
+      out[out_idx] = clamp<OutputType>(in[in_idx]);
     }
-    in_idx += idx;  // remaining dims have equal strides
-    out[out_idx] = clamp<OutputType>(in[in_idx]);
   }
 }
 
@@ -131,44 +138,50 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr
     inner_in_extent = Dims > 1 ? in_strides[LastDim - 1] : in_shape[LastDim] * in_strides[LastDim];
   }
 
-  for (; offset < block_end; offset += blockDim.x) {
-    uint64_t idx = offset;
-    uint64_t out_idx = idx;
-
-    // If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension,
-    // because we know the strides are 1 (or we treat them as 1 if we fused dimensions)
-    int i_c = 0;
-    int i_d;
-    bool out_of_bounds = false;
-    uint64_t in_idx = 0;
-
+  for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
+    #ifndef __clang__
     #pragma unroll
-    for (int d = 0; d < Dims - 1; d++) {
-      i_d = div_mod(idx, idx, out_strides[d]);
-      if (d == channel_dim)
+    #endif
+    for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
+      uint64_t idx = offset + i;
+      if (idx >= block_end) break;
+      uint64_t out_idx = idx;
+
+      // If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension,
+      // because we know the strides are 1 (or we treat them as 1 if we fused dimensions)
+      int i_c = 0;
+      int i_d;
+      bool out_of_bounds = false;
+      uint64_t in_idx = 0;
+
+      #pragma unroll
+      for (int d = 0; d < Dims - 1; d++) {
+        i_d = div_mod(idx, idx, out_strides[d]);
+        if (d == channel_dim)
+          i_c = i_d;
+        out_of_bounds |= is_out_of_bounds(anchor[d] + i_d, in_shape[d]);
+        if (!out_of_bounds)
+          in_idx += i_d * in_strides[d];
+      }
+
+      constexpr int d = LastDim;
+      i_d = idx;  // out_strides[d] is 1
+      if (AllDims && d == channel_dim)
         i_c = i_d;
-      out_of_bounds |= is_out_of_bounds(anchor[d] + i_d, in_shape[d]);
+      out_of_bounds |= is_out_of_bounds(inner_in_anchor + i_d, inner_in_extent);
       if (!out_of_bounds)
-        in_idx += i_d * in_strides[d];
-    }
-
-    constexpr int d = LastDim;
-    i_d = idx;  // out_strides[d] is 1
-    if (AllDims && d == channel_dim)
-      i_c = i_d;
-    out_of_bounds |= is_out_of_bounds(inner_in_anchor + i_d, inner_in_extent);
-    if (!out_of_bounds)
-      in_idx += i_d;  // in_strides[d] is 1
+        in_idx += i_d;  // in_strides[d] is 1
 
-    // Fill values are reused a lot, so let's make sure they are cached (by using __ldg())
-    out[out_idx] = out_of_bounds ? __ldg(&fill_values[i_c]) : clamp<OutputType>(in[in_idx]);
+      // Fill values are reused a lot, so let's make sure they are cached (by using __ldg())
+      out[out_idx] = out_of_bounds ? __ldg(&fill_values[i_c]) : clamp<OutputType>(in[in_idx]);
+    }
   }
 }
 
 template <typename OutputType, typename InputType, int Dims, bool SupportPad>
 __global__ void SliceKernel(const SliceSampleDesc<Dims> *samples, const SliceBlockDesc *blocks) {
   int sampleIdx = blocks[blockIdx.x].sampleIdx;
-  uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x;
+  uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * coalesced_values<OutputType>;
   uint64_t block_end = blocks[blockIdx.x].offset + blocks[blockIdx.x].size;
   auto sample = samples[sampleIdx];
   auto *out = static_cast<OutputType*>(sample.out);
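
Not part of the patch: a minimal, self-contained C++17 sketch of what the coalesced_values variable template introduced above evaluates to for typical output types. The loop changes suggest each thread now handles enough consecutive elements to cover at least 4 bytes of output, which is why the per-block stride becomes blockDim.x * coalesced_values<OutputType> and the kernel offsets threads by threadIdx.x * coalesced_values<OutputType>.

// Sketch only (not from the patch): restates the coalesced_values definition
// above and checks the per-thread element counts it implies.
#include <cstdint>

template <typename OutputType>
constexpr int coalesced_values = sizeof(OutputType) >= 4 ? 1 : 4 / sizeof(OutputType);

static_assert(coalesced_values<uint8_t>  == 4, "1-byte outputs: 4 values per thread");
static_assert(coalesced_values<uint16_t> == 2, "2-byte outputs: 2 values per thread");
static_assert(coalesced_values<float>    == 1, "4-byte outputs: 1 value per thread");
static_assert(coalesced_values<double>   == 1, "wider outputs: still 1 value per thread");

int main() {
  // With blockDim.x == 256 and uint8_t outputs, one outer-loop step of a block
  // now covers 256 * 4 = 1024 output elements instead of 256.
  return 0;
}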