
Commit 127f813

add bfloat16 gradient option to split_table_batched_embeddings_benchmark (#721)
Summary:
Pull Request resolved: #721

This Diff newly supports `--output-dtype=bf16` for both CPUs and GPUs.

Reviewed By: jianyuh

Differential Revision: D31432199

fbshipit-source-id: 90f75d17feabbc9d60fd6e3d2fd2de2ddd4dadf2
1 parent c016cb4 commit 127f813

16 files changed, with 487 additions and 135 deletions.

fbgemm_gpu/codegen/embedding_backward_code_generator.py

Lines changed: 3 additions & 3 deletions
@@ -400,14 +400,14 @@ def rowwise_adagrad() -> None:
         multiplier = shfl_sync(multiplier, 0);
     """
     split_weight_update_cpu = """
-        at::acc_type<scalar_t, true> g_local_sum_square = 0.0;
+        at::acc_type<grad_t, true> g_local_sum_square = 0.0;
         for (int64_t d = 0; d < D; ++d) {
             g_local_sum_square += grad_buffer[d] * grad_buffer[d];
         }
         auto g_avg_square = g_local_sum_square / D;
-        at::acc_type<scalar_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
+        at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square;
         momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads;
-        at::acc_type<scalar_t, true> multiplier;
+        at::acc_type<grad_t, true> multiplier;
         multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
         for (int64_t d = 0; d < D; ++d) {
             host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier;
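
The generated `split_weight_update_cpu` block implements the rowwise Adagrad step: it averages the squared gradient over the row, g_avg_square = (sum_d grad_buffer[d]^2) / D, folds that into the per-row momentum, and applies multiplier = learning_rate / (sqrt(new_sum_square_grads) + eps). The change retargets the accumulator type from the weight type (`scalar_t`) to the new gradient type (`grad_t`), so reduced-precision gradients still accumulate in fp32. A minimal standalone sketch (not part of this diff, assuming an ATen build) of what `at::acc_type<T, true>` resolves to:

    // Standalone sketch: at::acc_type<T, true> widens reduced-precision types
    // to float, so keying the accumulator off grad_t keeps bf16/fp16 gradient
    // sums in fp32 regardless of the weight type.
    #include <ATen/AccumulateType.h>
    #include <c10/util/BFloat16.h>
    #include <c10/util/Half.h>
    #include <type_traits>

    static_assert(
        std::is_same<at::acc_type<at::BFloat16, true>, float>::value,
        "bf16 gradients accumulate in fp32");
    static_assert(
        std::is_same<at::acc_type<at::Half, true>, float>::value,
        "fp16 gradients accumulate in fp32");
    static_assert(
        std::is_same<at::acc_type<float, true>, float>::value,
        "fp32 gradients stay in fp32");

    int main() { return 0; }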

fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp

Lines changed: 7 additions & 1 deletion
@@ -65,6 +65,11 @@ class SplitLookupFunction_Dense_Op
     ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
     ctx->saved_data["pooling_mode"] = pooling_mode;

+    int64_t output_dtype = -1 /* double */;
+    if (host_weights.scalar_type() == at::kHalf ||
+        host_weights.scalar_type() == at::ScalarType::Byte) {
+      output_dtype = static_cast<int64_t>(SparseType::FP32);
+    }
     return {split_embedding_codegen_forward_cpu(
         host_weights,
         weights_offsets,
@@ -74,7 +79,8 @@ class SplitLookupFunction_Dense_Op
         indices,
         offsets,
         pooling_mode,
-        indice_weights_value)};
+        indice_weights_value,
+        output_dtype)};
   }

   static torch::autograd::variable_list backward(
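
The dense CPU host picks the forward output dtype from the weight dtype before calling into the generated kernel. A distilled sketch of that rule as a free function; note the `SparseType` definition here is an assumed stand-in (only the `FP32` member is grounded in the hunk above), and `-1` is the sentinel the callee treats as the default path:

    #include <ATen/ATen.h>
    #include <cstdint>

    // Assumed stand-in for fbgemm_gpu's SparseType enum; only FP32 appears in
    // the diff, and its numeric value here is an assumption.
    enum class SparseType : int64_t { FP32 = 0 };

    // Sketch of the rule added in forward(): fp16 or uint8 host weights force
    // an fp32 output dtype; anything else keeps the -1 sentinel.
    int64_t choose_output_dtype(const at::Tensor& host_weights) {
      int64_t output_dtype = -1; /* double */
      if (host_weights.scalar_type() == at::kHalf ||
          host_weights.scalar_type() == at::ScalarType::Byte) {
        output_dtype = static_cast<int64_t>(SparseType::FP32);
      }
      return output_dtype;
    }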

fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp

Lines changed: 11 additions & 7 deletions
@@ -43,8 +43,8 @@ void split_embedding_backward_approx_cpu_kernel(
   const auto offsets_data = offsets.accessor<int64_t, 1>();
   // If indice_weights are not defined, then this accessor won't be used
   auto indice_weights_data = indice_weights.defined()
-      ? indice_weights.accessor<grad_t, 1>()
-      : at::TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr);
+      ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
+      : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr);

   for (int64_t t = 0; t < T; ++t) {
     int feature_begin = t; // to conform interface with exact
@@ -68,8 +68,8 @@ void split_embedding_backward_approx_cpu_kernel(
       for (int64_t d = 0; d < D; ++d) {
         grad_buffer[d] = scale_factor *
             (indice_weights.defined()
-                 ? grad_output_data[b][D_begin + d] * indice_weights_data[p]
-                 : grad_output_data[b][D_begin + d]);
+                 ? static_cast<scalar_t>(grad_output_data[b][D_begin + d] * indice_weights_data[p])
+                 : static_cast<scalar_t>(grad_output_data[b][D_begin + d]));
       }
       {{ split_weight_update_cpu }};
     } // for each p
@@ -99,7 +99,8 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
     {% if not dense %}
     bool stochastic_rounding,
     {% endif %}
-    {{args.split_function_args | join(", ")}}
+    {{args.split_function_args | join(", ")}},
+    int64_t output_dtype
 ) {
   int64_t T = D_offsets.numel() - 1;
   TORCH_CHECK(T > 0);
@@ -187,8 +188,11 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(

   {% endif %}

-  AT_DISPATCH_FLOATING_TYPES(
-      grad_output.scalar_type(), "split_embedding_backward_cpu", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      grad_output.scalar_type(),
+      "split_embedding_backward_cpu", [&]() {
         using grad_t = scalar_t;
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
             host_weights.scalar_type(),
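
The core move here is the nested dispatch: the outer `AT_DISPATCH_FLOATING_TYPES_AND2` extends the float/double set with `Half` and `BFloat16` for `grad_output`, and `using grad_t = scalar_t;` captures the gradient type before the inner weight dispatch rebinds `scalar_t`. A self-contained sketch of that shape (the kernel and names are illustrative, not from this file):

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    // Stand-in for the generated kernel: scalar_t is the weight element type,
    // grad_t the gradient element type.
    template <typename scalar_t, typename grad_t>
    void backward_kernel_sketch(const at::Tensor& grad, const at::Tensor& weights) {}

    void dispatch_sketch(const at::Tensor& grad_output, const at::Tensor& host_weights) {
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half,
          at::ScalarType::BFloat16,
          grad_output.scalar_type(),
          "sketch_outer", [&] {
            // Capture the gradient type before the inner macro shadows scalar_t.
            using grad_t = scalar_t;
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
                host_weights.scalar_type(), "sketch_inner", [&] {
                  backward_kernel_sketch<scalar_t, grad_t>(grad_output, host_weights);
                });
          });
    }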

fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp

Lines changed: 46 additions & 36 deletions
@@ -32,7 +32,7 @@ struct half2float16<at::Half> {
 } // namespace internal

 namespace {
-template <typename scalar_t>
+template <typename scalar_t, typename grad_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
@@ -53,9 +53,6 @@ void split_embedding_backward_exact_cpu_kernel(
     const at::TensorAccessor<int64_t, 1> momentum2_offsets_data,
     {% endif %}
     {{ args.split_cpu_kernel_args | join(", ") }}) {
-  using grad_t = at::acc_type<scalar_t, true>;
-
-  // const auto grad_output_accessor = grad_output.accessor<grad_t, 2>();
   const grad_t* grad_output_data = grad_output.data_ptr<grad_t>();
   auto host_weights_data = host_weights.accessor<scalar_t, 1>();
   const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
@@ -91,8 +88,8 @@ void split_embedding_backward_exact_cpu_kernel(
         offsets.accessor<int64_t, 1>(),
         indices.accessor<int64_t, 1>(),
         indice_weights.defined()
-            ? indice_weights.accessor<grad_t, 1>()
-            : at::TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
+            ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
+            : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr),
         pooling_mode,
         table_to_feature_offset + t,
         hash_size);
@@ -118,7 +115,8 @@ void split_embedding_backward_exact_cpu_kernel(
       table_to_feature_offset[t + 1] > table_to_feature_offset[t] + 1;

   {% if optimizer == "rowwise_adagrad" %}
-  constexpr bool use_fbgemm = std::is_same<scalar_t, float>::value;
+  constexpr bool use_fbgemm = std::is_same<scalar_t, float>::value
+      && std::is_same<scalar_t, grad_t>::value;
   // || std::is_same<scalar_t, at::Half>::value;
   if (use_fbgemm && !is_shared_table) {
     // fbgemm handles common case of no shared table
@@ -181,11 +179,11 @@ void split_embedding_backward_exact_cpu_kernel(
       // no fbgemm
       // TODO: to parallelize, we should easily identify segments belong to
       // the same column.
-      grad_t grad_buffer[D];
+      at::acc_type<grad_t, true> grad_buffer[D];
       for (int c = c_begin; c < c_end; ++c) {
         int64_t idx = col_segment_indices[c];
         if (c == c_begin || col_segment_indices[c - 1] != idx) {
-          memset(grad_buffer, 0, D * sizeof(grad_t));
+          memset(grad_buffer, 0, D * sizeof(at::acc_type<grad_t, true>));
         }
         const int64_t embedding_begin = table_begin + idx * D;
         for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
@@ -196,10 +194,12 @@ void split_embedding_backward_exact_cpu_kernel(
           }
           int b = batched_cscs[t].row_indices[r];
           for (int64_t d = 0; d < D; ++d) {
-            grad_buffer[d] += batched_cscs[t].weights != nullptr
-                ? grad_output_data[b * grad_stride + D_offset + d] *
-                    batched_cscs[t].weights[r]
-                : grad_output_data[b * grad_stride + D_offset + d];
+            if (batched_cscs[t].weights != nullptr) {
+              grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] *
+                  batched_cscs[t].weights[r];
+            } else {
+              grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d];
+            }
           }
         }
         if (c == c_end - 1 || col_segment_indices[c + 1] != idx) {
@@ -287,8 +287,11 @@ void split_embedding_backward_exact_cpu_dense_kernel(
     Tensor indice_weights,
     {% if not dense %}
     bool stochastic_rounding,
-    {% endif %}
+    {{ args.split_function_args | join(", ") }},
+    int64_t output_dtype
+    {% else %}
     {{ args.split_function_args | join(", ") }}
+    {% endif %}
 ) {

   int64_t T = D_offsets.numel() - 1;
@@ -326,28 +329,35 @@ void split_embedding_backward_exact_cpu_dense_kernel(

   grad_output = grad_output.contiguous();

-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&]() {
-        split_embedding_backward_exact_cpu_kernel<scalar_t>(
-            grad_output,
-            host_weights,
-            weights_offsets_data,
-            D_offsets_data,
-            hash_size_cumsum,
-            indices,
-            offsets,
-            pooling_mode,
-            indice_weights,
-            num_tables,
-            B,
-            table_to_feature_offset,
-            {% if "momentum1_offsets" in args.split_function_arg_names %}
-            momentum1_offsets_data,
-            {% endif %}
-            {% if "momentum2_offsets" in args.split_function_arg_names %}
-            momentum2_offsets_data,
-            {% endif %}
-            {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      grad_output.scalar_type(),
+      "split_embedding_backward_exact_cpu_outer", [&]() {
+        using grad_t = scalar_t;
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+            host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&]() {
+              split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
+                  grad_output,
+                  host_weights,
+                  weights_offsets_data,
+                  D_offsets_data,
+                  hash_size_cumsum,
+                  indices,
+                  offsets,
+                  pooling_mode,
+                  indice_weights,
+                  num_tables,
+                  B,
+                  table_to_feature_offset,
+                  {% if "momentum1_offsets" in args.split_function_arg_names %}
+                  momentum1_offsets_data,
+                  {% endif %}
+                  {% if "momentum2_offsets" in args.split_function_arg_names %}
+                  momentum2_offsets_data,
+                  {% endif %}
+                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+            });
       });

   return;
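
Two changes matter in the exact CPU kernel: the gradient type becomes an independent template parameter rather than being derived from the weight type, and the fused fbgemm rowwise-Adagrad fast path is now gated on both types being plain fp32. A standalone compile-time sketch of that gate (not from this diff):

    #include <c10/util/BFloat16.h>
    #include <type_traits>

    // Sketch of the tightened fast-path condition: the fused fbgemm kernel is
    // taken only when weight and gradient types are both float; bf16/fp16
    // gradients fall through to the generic accumulation loop.
    template <typename scalar_t, typename grad_t>
    constexpr bool use_fbgemm_fast_path() {
      return std::is_same<scalar_t, float>::value &&
          std::is_same<scalar_t, grad_t>::value;
    }

    static_assert(use_fbgemm_fast_path<float, float>(), "fp32/fp32 fuses");
    static_assert(!use_fbgemm_fast_path<float, c10::BFloat16>(), "bf16 grads fall back");

    int main() { return 0; }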

fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp

Lines changed: 14 additions & 6 deletions
@@ -28,7 +28,8 @@ void split_embedding_backward_codegen_{{ optimizer }}_cpu(
     int64_t pooling_mode,
     Tensor indice_weights,
     bool stochastic_rounding,
-    {{ args.split_function_args | join(", ") }});
+    {{ args.split_function_args | join(", ") }},
+    int64_t output_dtype);

 namespace {

@@ -52,7 +53,8 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
       bool gradient_clipping,
       double max_gradient,
       bool stochastic_rounding,
-      {{ args.split_function_args | join(", ") }}) {
+      {{ args.split_function_args | join(", ") }},
+      int64_t output_dtype) {
     Tensor indice_weights_value = indice_weights.value_or(Tensor());
     Tensor feature_requires_grad_value =
         feature_requires_grad.value_or(Tensor());
@@ -67,6 +69,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
     ctx->saved_data["gradient_clipping"] = gradient_clipping;
     ctx->saved_data["max_gradient"] = max_gradient;
     ctx->saved_data["stochastic_rounding"] = stochastic_rounding;
+    ctx->saved_data["output_dtype"] = output_dtype;

     {% for (var, _) in args.saved_data %}
     ctx->saved_data["{{ var }}"] = {{ var }};
@@ -81,7 +84,8 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
         indices,
         offsets,
         pooling_mode,
-        indice_weights_value)};
+        indice_weights_value,
+        output_dtype)};
   }

   static torch::autograd::variable_list backward(
@@ -110,6 +114,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
     auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool();
     auto max_gradient = ctx->saved_data["max_gradient"].toDouble();
     auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool();
+    auto output_dtype = ctx->saved_data["output_dtype"].toInt();

     {% for (var, ivalue_cast) in args.saved_data %}
     auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
@@ -134,7 +139,8 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
         pooling_mode,
         indice_weights,
         stochastic_rounding,
-        {{ args.split_function_arg_names | join(", ") }});
+        {{ args.split_function_arg_names | join(", ") }},
+        output_dtype);
     // NOTE: MEAN pooling will not work with indice_weights!
     auto grad_indice_weights = indice_weights.defined()
         ? split_embedding_codegen_grad_indice_weights_cpu(
@@ -163,7 +169,8 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
         Variable(), // gradient_clipping
         Variable(), // max_gradient
         Variable(), // stochastic_rounding
-        {{ args.split_variables | join(", ") }}
+        {{ args.split_variables | join(", ") }},
+        Variable(), // output_dtype
     };
   }
 };
@@ -204,7 +211,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
       gradient_clipping,
       max_gradient,
       stochastic_rounding,
-      {{ args.split_function_arg_names | join(", ") }})[0];
+      {{ args.split_function_arg_names | join(", ") }},
+      output_dtype)[0];
 }

 TORCH_LIBRARY_FRAGMENT(fb, m) {
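
The host template threads `output_dtype` across the autograd boundary: stash it in `saved_data` on the way forward, read it back in backward, and append an undefined `Variable()` to the returned gradient list, since an integer argument has no gradient. A minimal standalone sketch of that pattern (the op and its arithmetic are illustrative, not this file's generated class):

    #include <torch/torch.h>

    // Illustrative custom autograd op showing the plumbing used above: a
    // non-tensor argument rides along in saved_data and is answered with an
    // undefined Variable() slot in backward.
    struct SketchOp : public torch::autograd::Function<SketchOp> {
      static torch::Tensor forward(
          torch::autograd::AutogradContext* ctx,
          torch::Tensor input,
          int64_t output_dtype) {
        ctx->saved_data["output_dtype"] = output_dtype; // int64_t -> IValue
        return input * 2;
      }

      static torch::autograd::variable_list backward(
          torch::autograd::AutogradContext* ctx,
          torch::autograd::variable_list grad_outputs) {
        auto output_dtype = ctx->saved_data["output_dtype"].toInt();
        (void)output_dtype; // the real kernels dispatch on this value
        // One slot per forward input: a tensor grad, then Variable() for the int.
        return {grad_outputs[0] * 2, torch::autograd::Variable()};
      }
    };

    // Usage: auto y = SketchOp::apply(x, /*output_dtype=*/-1);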

fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu

Lines changed: 10 additions & 10 deletions
@@ -17,11 +17,11 @@ constexpr int32_t kCacheLocationMissing = -1;
 constexpr size_t kForwardMaxThreads = 512;

 // TODO: optimization to use multiple warps per row.
-template <typename emb_t, typename cache_t, size_t kMaxVecsPerThread>
+template <typename emb_t, typename grad_t, typename cache_t, size_t kMaxVecsPerThread>
 __global__
 __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights_kernel(
     // [\sum_t E_t x D_t]
-    const at::PackedTensorAccessor32<at::acc_type<cache_t, true>, 2, at::RestrictPtrTraits>
+    const at::PackedTensorAccessor32<grad_t, 2, at::RestrictPtrTraits>
         grad_output,
     at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
     {% if not dense %}
@@ -92,7 +92,7 @@ __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_e
        i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
        ++i) {
     int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
-    Vec4T<at::acc_type<cache_t, true>> go((&grad_output[b][0]) + D_start + d);
+    Vec4T<at::acc_type<grad_t, true>> go((&grad_output[b][0]) + D_start + d);
     grad_out[i] = go;
   }

@@ -213,18 +213,19 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights
   const auto B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B >= 0);
   TORCH_CHECK(max_D <= {{ max_embedding_dim }});
-  auto grad_indice_weights = empty_like(indices, indices.options().dtype(grad_output.dtype()));
+  auto grad_indice_weights = empty_like(indices, indices.options().dtype(at::toAccumulateType(grad_output.scalar_type(), true)));
   if (B == 0) {
     return grad_indice_weights;
   }
   feature_requires_grad = feature_requires_grad.defined() ? feature_requires_grad : at::empty({0}, indices.options().dtype(at::kInt));
   {% if not dense %}
-  DISPATCH_EMB_CACHE_TYPES(
+  DISPATCH_EMB_GRAD_CACHE_TYPES(
   {% else %}
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
   {% endif %}
       dev_weights.type(),
       {% if not dense %}
+      grad_output.type(),
       lxu_cache_weights.type(),
       {% endif %}
       "split_embedding_codegen_grad_indice_weights_kernel",
@@ -234,9 +235,11 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights
         {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights_kernel<
             {% if not dense %}
             emb_t,
+            grad_t,
             cache_t,
             {% else %}
             scalar_t,
+            at::acc_type<scalar_t, true>,
             scalar_t,
             {% endif %}
             {{ kMaxVecsPerThread }}><<<
@@ -245,10 +248,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights
             0,
             at::cuda::getCurrentCUDAStream()>>>(
             {% if not dense %}
-            grad_output.packed_accessor32<
-                at::acc_type<cache_t, true>,
-                2,
-                at::RestrictPtrTraits>(),
+            grad_output.packed_accessor32<grad_t, 2, at::RestrictPtrTraits>(),
             dev_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
             {% else %}
             grad_output.packed_accessor32<
@@ -271,7 +271,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights
             {% endif %}
             feature_requires_grad.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
             {% if not dense %}
-            grad_indice_weights.packed_accessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits>()
+            grad_indice_weights.packed_accessor32<at::acc_type<grad_t, true>, 1, at::RestrictPtrTraits>()
             {% else %}
             grad_indice_weights.packed_accessor32<at::acc_type<scalar_t, true>, 1, at::RestrictPtrTraits>()
             {% endif %}