|
4 | 4 | #include <c10/cuda/CUDAStream.h> |
5 | 5 | #include <torch/csrc/stable/library.h> |
6 | 6 | #include <torch/csrc/stable/ops.h> |
| 7 | +#include <torch/headeronly/core/Dispatch_v2.h> |
7 | 8 |
|
8 | 9 | namespace torchaudio { |
9 | 10 | namespace rnnt { |
@@ -117,33 +118,21 @@ std::tuple<Tensor, Tensor> compute( |
117 | 118 | /*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()), |
118 | 119 | /*int_size=*/int_workspace.numel()); |
119 | 120 |
|
120 | | - switch (logits.scalar_type()) { |
121 | | - case ScalarType::Float: { |
122 | | - Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>( |
123 | | - /*workspace=*/workspace, |
124 | | - /*logits=*/reinterpret_cast<float*>(logits.data_ptr()), |
125 | | - /*targets=*/reinterpret_cast<int*>(targets.data_ptr()), |
126 | | - /*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()), |
127 | | - /*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()), |
128 | | - /*costs=*/reinterpret_cast<float*>(costs.data_ptr()), |
129 | | - /*gradients=*/reinterpret_cast<float*>(gradients.data_ptr())); |
130 | | - break; |
131 | | - } |
132 | | - case ScalarType::Half: { |
133 | | - Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>( |
134 | | - /*workspace=*/workspace, |
135 | | - /*logits=*/reinterpret_cast<c10::Half*>(logits.data_ptr()), |
136 | | - /*targets=*/reinterpret_cast<int*>(targets.data_ptr()), |
137 | | - /*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()), |
138 | | - /*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()), |
139 | | - /*costs=*/reinterpret_cast<c10::Half*>(costs.data_ptr()), |
140 | | - /*gradients=*/reinterpret_cast<c10::Half*>(gradients.data_ptr())); |
141 | | - break; |
142 | | - } |
143 | | - default: { |
144 | | - STD_TORCH_CHECK(false, "unreachable"); |
145 | | - } |
146 | | - }; |
| 121 | + THO_DISPATCH_V2( |
| 122 | + logits.scalar_type(), |
| 123 | + "rnnt:compute", |
| 124 | + AT_WRAP([&] { |
| 125 | + (Compute</*DTYPE=*/scalar_t, /*CAST_DTYPE=*/float>( |
| 126 | + /*workspace=*/workspace, |
| 127 | + /*logits=*/reinterpret_cast<scalar_t*>(logits.data_ptr()), |
| 128 | + /*targets=*/reinterpret_cast<int*>(targets.data_ptr()), |
| 129 | + /*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()), |
| 130 | + /*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()), |
| 131 | + /*costs=*/reinterpret_cast<scalar_t*>(costs.data_ptr()), |
| 132 | + /*gradients=*/reinterpret_cast<scalar_t*>(gradients.data_ptr()))); |
| 133 | + }), |
| 134 | + ScalarType::Float, |
| 135 | + ScalarType::Half); |
147 | 136 |
|
148 | 137 | return std::make_tuple(costs, gradients); |
149 | 138 | } |
|
0 commit comments