Skip to content

Commit a0ec9d2

Browse files
szkarpinski authored and cyyever committed
Coalesce stores in Slice for smaller output types (NVIDIA#3568)
* Coalesce stores in Slice for smaller output types. This change coalesces stores to global memory in SliceGPU when OutputType is smaller than 4 bytes in order to improve performance. Signed-off-by: Szymon Karpiński <hugo@staszic.waw.pl>
1 parent b2cc29a commit a0ec9d2

File tree

1 file changed

+51
-38
lines changed

1 file changed

+51
-38
lines changed

dali/kernels/slice/slice_gpu.cuh

+51-38
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ struct SliceBlockDesc {
7272
uint64_t size;
7373
};
7474

75+
template<typename OutputType>
76+
constexpr int coalesced_values = sizeof(OutputType) >= 4 ? 1 : 4 / sizeof(OutputType);
77+
7578
/**
7679
* @brief Simplified algorithm when no padding is necessary
7780
* @remarks `in` already refers to the slice anchor start
@@ -87,18 +90,22 @@ __device__ void SliceFuncNoPad(OutputType *__restrict__ out, const InputType *__
8790
return;
8891
}
8992

90-
for (; offset < block_end; offset += blockDim.x) {
91-
uint64_t idx = offset;
92-
uint64_t out_idx = idx;
93-
uint64_t in_idx = 0;
94-
93+
for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
9594
#pragma unroll
96-
for (int d = 0; d < Dims; d++) {
97-
int i_d = div_mod(idx, idx, out_strides[d]);
98-
in_idx += i_d * in_strides[d];
95+
for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
96+
uint64_t idx = offset + i;
97+
if (idx >= block_end) break;
98+
uint64_t out_idx = idx;
99+
uint64_t in_idx = 0;
100+
101+
#pragma unroll
102+
for (int d = 0; d < Dims; d++) {
103+
int i_d = div_mod(idx, idx, out_strides[d]);
104+
in_idx += i_d * in_strides[d];
105+
}
106+
in_idx += idx; // remaining dims have equal strides
107+
out[out_idx] = clamp<OutputType>(in[in_idx]);
99108
}
100-
in_idx += idx; // remaining dims have equal strides
101-
out[out_idx] = clamp<OutputType>(in[in_idx]);
102109
}
103110
}
104111

@@ -131,44 +138,50 @@ __device__ void SliceFunc(OutputType *__restrict__ out, const InputType *__restr
131138
inner_in_extent = Dims > 1 ? in_strides[LastDim - 1] : in_shape[LastDim] * in_strides[LastDim];
132139
}
133140

134-
for (; offset < block_end; offset += blockDim.x) {
135-
uint64_t idx = offset;
136-
uint64_t out_idx = idx;
137-
138-
// If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension,
139-
// because know the strides are 1 (or we treat them as 1 if we fused dimensions)
140-
int i_c = 0;
141-
int i_d;
142-
bool out_of_bounds = false;
143-
uint64_t in_idx = 0;
144-
141+
for (; offset < block_end; offset += blockDim.x * coalesced_values<OutputType>) {
142+
#ifndef __clang__
145143
#pragma unroll
146-
for (int d = 0; d < Dims - 1; d++) {
147-
i_d = div_mod(idx, idx, out_strides[d]);
148-
if (d == channel_dim)
144+
#endif
145+
for (uint64_t i = 0; i < coalesced_values<OutputType>; i++) {
146+
uint64_t idx = offset + i;
147+
if (idx >= block_end) break;
148+
uint64_t out_idx = idx;
149+
150+
// If no dimensions were skipped (AllDims=true) we can avoid division in the last dimension,
151+
// because know the strides are 1 (or we treat them as 1 if we fused dimensions)
152+
int i_c = 0;
153+
int i_d;
154+
bool out_of_bounds = false;
155+
uint64_t in_idx = 0;
156+
157+
#pragma unroll
158+
for (int d = 0; d < Dims - 1; d++) {
159+
i_d = div_mod(idx, idx, out_strides[d]);
160+
if (d == channel_dim)
161+
i_c = i_d;
162+
out_of_bounds |= is_out_of_bounds(anchor[d] + i_d, in_shape[d]);
163+
if (!out_of_bounds)
164+
in_idx += i_d * in_strides[d];
165+
}
166+
167+
constexpr int d = LastDim;
168+
i_d = idx; // out_strides[d] is 1
169+
if (AllDims && d == channel_dim)
149170
i_c = i_d;
150-
out_of_bounds |= is_out_of_bounds(anchor[d] + i_d, in_shape[d]);
171+
out_of_bounds |= is_out_of_bounds(inner_in_anchor + i_d, inner_in_extent);
151172
if (!out_of_bounds)
152-
in_idx += i_d * in_strides[d];
153-
}
154-
155-
constexpr int d = LastDim;
156-
i_d = idx; // out_strides[d] is 1
157-
if (AllDims && d == channel_dim)
158-
i_c = i_d;
159-
out_of_bounds |= is_out_of_bounds(inner_in_anchor + i_d, inner_in_extent);
160-
if (!out_of_bounds)
161-
in_idx += i_d; // in_strides[d] is 1
173+
in_idx += i_d; // in_strides[d] is 1
162174

163-
// Fill values are reused a lot, so let's make sure they are cached (by using __ldg())
164-
out[out_idx] = out_of_bounds ? __ldg(&fill_values[i_c]) : clamp<OutputType>(in[in_idx]);
175+
// Fill values are reused a lot, so let's make sure they are cached (by using __ldg())
176+
out[out_idx] = out_of_bounds ? __ldg(&fill_values[i_c]) : clamp<OutputType>(in[in_idx]);
177+
}
165178
}
166179
}
167180

168181
template <typename OutputType, typename InputType, int Dims, bool SupportPad>
169182
__global__ void SliceKernel(const SliceSampleDesc<Dims> *samples, const SliceBlockDesc *blocks) {
170183
int sampleIdx = blocks[blockIdx.x].sampleIdx;
171-
uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x;
184+
uint64_t offset = blocks[blockIdx.x].offset + threadIdx.x * coalesced_values<OutputType>;
172185
uint64_t block_end = blocks[blockIdx.x].offset + blocks[blockIdx.x].size;
173186
auto sample = samples[sampleIdx];
174187
auto *out = static_cast<OutputType*>(sample.out);

0 commit comments

Comments
 (0)