Skip to content

Commit 4454ac5

Browse files
jspark1105 authored and facebook-github-bot committed
add bf16 support in jagged tensor ops (#1079)
Summary: Pull Request resolved: #1079 To support bf16 training Reviewed By: ajtulloch Differential Revision: D35955466 fbshipit-source-id: 0f740f29074576c026005362c78f872fec80bbcc
1 parent dfb36cd commit 4454ac5

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

fbgemm_gpu/src/jagged_tensor_ops.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,9 @@ at::Tensor jagged_to_padded_dense_forward(
428428
Tensor padded_values_view =
429429
values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;
430430

431-
AT_DISPATCH_ALL_TYPES_AND(
431+
AT_DISPATCH_ALL_TYPES_AND2(
432432
at::ScalarType::Half,
433+
at::ScalarType::BFloat16,
433434
values.scalar_type(),
434435
"jagged_to_padded_dense",
435436
[&] {
@@ -461,7 +462,9 @@ at::Tensor jagged_to_padded_dense_backward(
461462
auto grad_values =
462463
at::zeros({max_lengths[0], D}, grad_padded_values.options());
463464

464-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
465+
AT_DISPATCH_FLOATING_TYPES_AND2(
466+
at::ScalarType::Half,
467+
at::ScalarType::BFloat16,
465468
grad_padded_values.scalar_type(),
466469
"jagged_2d_to_dense_backward_kernel",
467470
[&] {
@@ -904,7 +907,9 @@ class BatchedDenseVecJagged2DMulGPUOp
904907

905908
AT_DISPATCH_INDEX_TYPES(
906909
a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
907-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
910+
AT_DISPATCH_FLOATING_TYPES_AND2(
911+
at::ScalarType::Half,
912+
at::ScalarType::BFloat16,
908913
a_values.scalar_type(),
909914
"dense_vec_jagged_2d_bmm_kernel_2",
910915
[&] {
@@ -963,7 +968,9 @@ class BatchedDenseVecJagged2DMulGPUOp
963968
a_offsets.scalar_type(),
964969
"dense_vec_jagged_2d_bmm_baackward_kernel_1",
965970
[&] {
966-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
971+
AT_DISPATCH_FLOATING_TYPES_AND2(
972+
at::ScalarType::Half,
973+
at::ScalarType::BFloat16,
967974
grad_outputs[0].scalar_type(),
968975
"dense_vec_jagged_2d_bmm_baackward_kernel_2",
969976
[&] {

fbgemm_gpu/test/sparse_ops_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,7 @@ def test_dense_to_jagged(
18361836
outer_dense_size=st.integers(0, 5),
18371837
inner_dense_size=st.integers(0, 5),
18381838
padding_value=st.sampled_from([0, -1e-8]),
1839+
dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16, torch.double]),
18391840
use_cpu=st.booleans() if gpu_available else st.just(True),
18401841
)
18411842
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -1845,8 +1846,12 @@ def test_jagged_to_padded_dense(
18451846
outer_dense_size: int,
18461847
inner_dense_size: int,
18471848
padding_value: float,
1849+
dtype: torch.dtype,
18481850
use_cpu: bool,
18491851
) -> None:
1852+
# CPU doesn't support bfloat16
1853+
assume(not use_cpu or dtype != torch.bfloat16)
1854+
18501855
# Testing with a basic crafted example.
18511856
# dense representation is
18521857
# [[[[0, 1], [ 0, 0], [0, 0]],
@@ -2006,7 +2011,7 @@ def mul_func(*args) -> torch.Tensor:
20062011
H=st.integers(1, 3),
20072012
max_L=st.integers(1, 32),
20082013
D=st.integers(0, 32),
2009-
dtype=st.sampled_from([torch.float, torch.half, torch.double]),
2014+
dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16, torch.double]),
20102015
use_cpu=st.booleans() if gpu_available else st.just(True),
20112016
)
20122017
def test_batched_dense_vec_jagged_2d_mul(
@@ -2019,6 +2024,9 @@ def test_batched_dense_vec_jagged_2d_mul(
20192024
use_cpu: bool,
20202025
) -> None:
20212026
assume(H == 1 or B != 0)
2027+
# CPU doesn't support bfloat16
2028+
assume(not use_cpu or dtype != torch.bfloat16)
2029+
20222030
device = torch.device("cpu" if use_cpu else "cuda")
20232031
torch.backends.cuda.matmul.allow_tf32 = False
20242032

0 commit comments

Comments (0)