Skip to content

Commit 044931f

Browse files
authored
Make sure that vectorize_with_alignment produced vectorized global loads (#23182)
1 parent 1d353b6 commit 044931f

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

csrc/quantization/vectorization_utils.cuh

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment(
4141

4242
for (int i = tid; i < num_vec; i += stride) {
4343
vout_t tmp;
44-
vec_op(tmp, v_in[i]);
45-
v_out[i] = tmp;
44+
// Make a local copy of the entire pack
45+
vin_t src = v_in[i]; // <- encourages a single vector ld
46+
vec_op(tmp, src);
47+
v_out[i] = tmp; // <- encourages a single vector st
4648
}
4749
return;
4850
}
@@ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment(
7173
// 2. vectorize the main part
7274
for (int i = tid; i < num_vec; i += stride) {
7375
vout_t tmp;
74-
vec_op(tmp, v_in[i]);
75-
v_out[i] = tmp;
76+
// Make a local copy of the entire pack
77+
vin_t src = v_in[i]; // <- encourages a single vector ld
78+
vec_op(tmp, src);
79+
v_out[i] = tmp; // <- encourages a single vector st
7680
}
7781

7882
// 3. handle the tail
@@ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len,
125129
auto* v_in = reinterpret_cast<const vin_t*>(in);
126130

127131
for (int i = tid; i < num_vec; i += stride) {
128-
vec_op(v_in[i]);
132+
vin_t tmp = v_in[i];
133+
vec_op(tmp);
129134
}
130135
return;
131136
}

0 commit comments

Comments
 (0)