Commit 2d16f61

allow FP16-type grad_t (#1072)
Summary: Pull Request resolved: #1072. This diff partially revives D31432199 (127f813), but enables only `grad_t = FP16` (no `BF16` support) in order to limit adverse side effects such as increased binary size and compilation time. Specifically, D31432199 (127f813) provided FP32, FP16, and BF16 options for `grad_t`; this diff drops the BF16 option, leaving only FP32 and FP16 for `grad_t`.

Reviewed By: jianyuh

Differential Revision: D35120293

fbshipit-source-id: b9a1d35f901b26277a220360a2a68583c65c8554
1 parent 4454ac5 commit 2d16f61
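The binary-size and compile-time concern in the summary comes from dispatch fan-out: every dtype admitted for `grad_t` is combined with every embedding and cache dtype, and each combination instantiates the kernel template. A back-of-the-envelope sketch; the type counts here are illustrative assumptions, not the exact FBGEMM configuration:

```cpp
// Hypothetical fan-out arithmetic: each extra grad_t dtype multiplies the
// number of kernel template instantiations the dispatch macros emit.
#include <cstdio>

int main() {
  const int emb_types = 2;            // e.g., float and at::Half (illustrative)
  const int cache_types = 2;          // e.g., float and at::Half (illustrative)
  const int grad_types_fp32_fp16 = 2; // this diff: float, at::Half
  const int grad_types_with_bf16 = 3; // D31432199: float, at::Half, at::BFloat16

  std::printf("instantiations per kernel: %d with FP32+FP16, %d with BF16 too\n",
              emb_types * cache_types * grad_types_fp32_fp16,
              emb_types * cache_types * grad_types_with_bf16);
  return 0;
}
```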

File tree

4 files changed: 116 additions & 48 deletions

fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp

Lines changed: 28 additions & 24 deletions
```diff
@@ -344,31 +344,35 @@ void split_embedding_backward_exact_cpu_dense_kernel(
 
   grad_output = grad_output.contiguous();
 
+
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
-        // TODO: respect output_dtype
-        using grad_t = float;
-        split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
-            grad_output,
-            host_weights,
-            weights_offsets_data,
-            D_offsets_data,
-            hash_size_cumsum,
-            indices,
-            offsets,
-            pooling_mode,
-            indice_weights,
-            num_tables,
-            B,
-            table_to_feature_offset,
-            {% if "momentum1_offsets" in args.split_function_arg_names %}
-            momentum1_offsets_data,
-            {% endif %}
-            {% if "momentum2_offsets" in args.split_function_arg_names %}
-            momentum2_offsets_data,
-            {% endif %}
-            {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
-      });
+      grad_output.scalar_type(),
+      "split_embedding_backward_exact_cpu_outer", [&]() {
+        using grad_t = scalar_t;
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+            host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
+              split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
+                  grad_output,
+                  host_weights,
+                  weights_offsets_data,
+                  D_offsets_data,
+                  hash_size_cumsum,
+                  indices,
+                  offsets,
+                  pooling_mode,
+                  indice_weights,
+                  num_tables,
+                  B,
+                  table_to_feature_offset,
+                  {% if "momentum1_offsets" in args.split_function_arg_names %}
+                  momentum1_offsets_data,
+                  {% endif %}
+                  {% if "momentum2_offsets" in args.split_function_arg_names %}
+                  momentum2_offsets_data,
+                  {% endif %}
+                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+            });
+      });
 
   return;
 
```
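The key move in this hunk is the nested dispatch: the outer `AT_DISPATCH_FLOATING_TYPES_AND_HALF` binds the gradient dtype, and `using grad_t = scalar_t;` captures it before the inner dispatch re-binds `scalar_t` to the weight dtype. A minimal standalone sketch of that shadowing pattern, assuming libtorch is available (`toy_backward` and `toy_backward_kernel` are hypothetical stand-ins for the generated code):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <cstdio>

// Hypothetical stand-in for split_embedding_backward_exact_cpu_kernel.
template <typename weight_t, typename grad_t>
void toy_backward_kernel(const at::Tensor& grad, const at::Tensor& weights) {
  std::printf("sizeof(grad_t)=%zu sizeof(weight_t)=%zu\n",
              sizeof(grad_t), sizeof(weight_t));
}

void toy_backward(const at::Tensor& grad, const at::Tensor& weights) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "toy_backward_outer", [&] {
        // Capture the outer dispatch's type now; the inner dispatch below
        // shadows scalar_t with the weight dtype.
        using grad_t = scalar_t;
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            weights.scalar_type(), "toy_backward_inner", [&] {
              toy_backward_kernel<scalar_t, grad_t>(grad, weights);
            });
      });
}
```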
fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -194,8 +194,8 @@ Tensor split_embedding_codegen_forward_cpu(
   // It is assumed that the indice_weights will always be float
   TORCH_CHECK(
       !indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);
-  AT_DISPATCH_FLOATING_TYPES(
-      output.scalar_type(), "split_embedding_cpu_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      output.scalar_type(), "split_embedding_cpu_forward", [&]() {
         using output_t = scalar_t;
         AT_DISPATCH_FLOATING_TYPES_AND2(
             at::ScalarType::Half,
```
fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h

Lines changed: 18 additions & 16 deletions
```diff
@@ -137,20 +137,22 @@
   }                                                                         \
 }
 
-#define DISPATCH_EMB_GRAD_CACHE_TYPES(                                      \
-    EMB_TYPE, GRAD_TYPE, CACHE_TYPE, NAME, ...)                             \
-  [&] {                                                                     \
-    const auto& emb_type = EMB_TYPE;                                        \
-    const auto& grad_type = GRAD_TYPE;                                      \
-    const auto& cache_type = CACHE_TYPE;                                    \
-    at::ScalarType _emb_t = ::detail::scalar_type(emb_type);                \
-    at::ScalarType _grad_t = ::detail::scalar_type(grad_type);              \
-    at::ScalarType _cache_t = ::detail::scalar_type(cache_type);            \
-    switch (_grad_t) {                                                      \
-      PRIVATE_CASE_TYPE_CACHE_EMB(                                          \
-          at::ScalarType::Float, _cache_t, _emb_t, float, NAME, __VA_ARGS__) \
-      default:                                                              \
-        AT_ERROR(                                                           \
-            #NAME, " not implemented for grad_t '", toString(_grad_t), "'"); \
-    }                                                                       \
+#define DISPATCH_EMB_GRAD_CACHE_TYPES(                                       \
+    EMB_TYPE, GRAD_TYPE, CACHE_TYPE, NAME, ...)                              \
+  [&] {                                                                      \
+    const auto& emb_type = EMB_TYPE;                                         \
+    const auto& grad_type = GRAD_TYPE;                                       \
+    const auto& cache_type = CACHE_TYPE;                                     \
+    at::ScalarType _emb_t = ::detail::scalar_type(emb_type);                 \
+    at::ScalarType _grad_t = ::detail::scalar_type(grad_type);               \
+    at::ScalarType _cache_t = ::detail::scalar_type(cache_type);             \
+    switch (_grad_t) {                                                       \
+      PRIVATE_CASE_TYPE_CACHE_EMB(                                           \
+          at::ScalarType::Float, _cache_t, _emb_t, float, NAME, __VA_ARGS__) \
+      PRIVATE_CASE_TYPE_CACHE_EMB(                                           \
+          at::ScalarType::Half, _cache_t, _emb_t, at::Half, NAME, __VA_ARGS__) \
+      default:                                                               \
+        AT_ERROR(                                                            \
+            #NAME, " not implemented for grad_t '", toString(_grad_t), "'"); \
+    }                                                                        \
   }()
```
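The macro change follows the same switch-on-`ScalarType` pattern PyTorch's own dispatch macros use: each `PRIVATE_CASE_TYPE_CACHE_EMB` expands to a `case` that binds the C++ type and then fans out over the cache and embedding types. A simplified single-level sketch of the pattern (this `DEMO_DISPATCH_GRAD_TYPE` macro is illustrative, not FBGEMM's actual helper):

```cpp
#include <ATen/ATen.h>

// Simplified: dispatch only over grad_t; the real macro nests further
// switches for cache_t and emb_t inside each case.
#define DEMO_DISPATCH_GRAD_TYPE(GRAD_TYPE, NAME, ...)            \
  [&] {                                                          \
    const at::ScalarType _grad_t = GRAD_TYPE;                    \
    switch (_grad_t) {                                           \
      case at::ScalarType::Float: {                              \
        using grad_t = float;                                    \
        return __VA_ARGS__();                                    \
      }                                                          \
      case at::ScalarType::Half: { /* the case this diff adds */ \
        using grad_t = at::Half;                                 \
        return __VA_ARGS__();                                    \
      }                                                          \
      default:                                                   \
        AT_ERROR(#NAME, " not implemented for grad_t '",         \
                 toString(_grad_t), "'");                        \
    }                                                            \
  }()
```

A call site would look like `DEMO_DISPATCH_GRAD_TYPE(grad.scalar_type(), demo, [&] { /* grad_t is bound here */ });`, where the lambda body is pasted into each `case` so the locally defined `grad_t` alias is in scope.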
