Skip to content

Commit 50f139c

Browse files
authored
[STABLE ABI] Eliminate C10 and ATen dependencies. (#4137)
1 parent 29d85c0 commit 50f139c

File tree

6 files changed

+28
-43
lines changed

6 files changed

+28
-43
lines changed

src/libtorchaudio/rnnt/cpu/cpu_kernels.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include <libtorchaudio/rnnt/cpu/math.h>
44
#include <libtorchaudio/rnnt/options.h>
55
#include <libtorchaudio/rnnt/types.h>
6-
7-
#include <c10/util/Logging.h>
6+
#include <torch/headeronly/util/Exception.h>
87

98
#include <cstring>
109
#include <limits>
@@ -50,7 +49,7 @@ class TensorView {
5049
}
5150

5251
DTYPE& operator()(const std::vector<int>& indices) {
53-
TORCH_CHECK_EQ(indices.size(), dims_.size());
52+
STD_TORCH_CHECK(indices.size() == dims_.size());
5453
int index = indices.back();
5554
for (int i = indices.size() - 2; i >= 0; --i) {
5655
index += indices[i] * strides_[i];

src/libtorchaudio/rnnt/cpu/cpu_transducer.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ status_t Compute(
2828
DTYPE* gradients = nullptr) {
2929
const Options& options = workspace.GetOptions();
3030

31-
TORCH_CHECK_EQ(options.device_, CPU);
31+
STD_TORCH_CHECK(options.device_ == CPU);
3232

3333
const int& B = options.batchSize_;
3434
const int& maxT = options.maxSrcLen_;
@@ -91,7 +91,7 @@ status_t ComputeAlphas(
9191
DTYPE* alphas) {
9292
const Options& options = workspace.GetOptions();
9393

94-
TORCH_CHECK_EQ(options.device_, CPU);
94+
STD_TORCH_CHECK(options.device_ == CPU);
9595

9696
const int& B = options.batchSize_;
9797
const int& maxT = options.maxSrcLen_;
@@ -140,7 +140,7 @@ status_t ComputeBetas(
140140
DTYPE* betas) {
141141
const Options& options = workspace.GetOptions();
142142

143-
TORCH_CHECK_EQ(options.device_, CPU);
143+
STD_TORCH_CHECK(options.device_ == CPU);
144144

145145
const int& B = options.batchSize_;
146146
const int& maxT = options.maxSrcLen_;

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <c10/cuda/CUDAStream.h>
55
#include <torch/csrc/stable/library.h>
66
#include <torch/csrc/stable/ops.h>
7+
#include <torch/headeronly/core/Dispatch_v2.h>
78

89
namespace torchaudio {
910
namespace rnnt {
@@ -117,33 +118,21 @@ std::tuple<Tensor, Tensor> compute(
117118
/*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
118119
/*int_size=*/int_workspace.numel());
119120

120-
switch (logits.scalar_type()) {
121-
case ScalarType::Float: {
122-
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
123-
/*workspace=*/workspace,
124-
/*logits=*/reinterpret_cast<float*>(logits.data_ptr()),
125-
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
126-
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
127-
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
128-
/*costs=*/reinterpret_cast<float*>(costs.data_ptr()),
129-
/*gradients=*/reinterpret_cast<float*>(gradients.data_ptr()));
130-
break;
131-
}
132-
case ScalarType::Half: {
133-
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
134-
/*workspace=*/workspace,
135-
/*logits=*/reinterpret_cast<c10::Half*>(logits.data_ptr()),
136-
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
137-
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
138-
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
139-
/*costs=*/reinterpret_cast<c10::Half*>(costs.data_ptr()),
140-
/*gradients=*/reinterpret_cast<c10::Half*>(gradients.data_ptr()));
141-
break;
142-
}
143-
default: {
144-
STD_TORCH_CHECK(false, "unreachable");
145-
}
146-
};
121+
THO_DISPATCH_V2(
122+
logits.scalar_type(),
123+
"rnnt:compute",
124+
AT_WRAP([&] {
125+
(Compute</*DTYPE=*/scalar_t, /*CAST_DTYPE=*/float>(
126+
/*workspace=*/workspace,
127+
/*logits=*/reinterpret_cast<scalar_t*>(logits.data_ptr()),
128+
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
129+
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
130+
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
131+
/*costs=*/reinterpret_cast<scalar_t*>(costs.data_ptr()),
132+
/*gradients=*/reinterpret_cast<scalar_t*>(gradients.data_ptr())));
133+
}),
134+
ScalarType::Float,
135+
ScalarType::Half);
147136

148137
return std::make_tuple(costs, gradients);
149138
}

src/libtorchaudio/rnnt/gpu/half.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#ifdef USE_C10_HALF
4-
#include "c10/util/Half.h"
4+
#include <torch/headeronly/util/Half.h>
55
#endif // USE_C10_HALF
66

77
#include <libtorchaudio/rnnt/macros.h>

src/libtorchaudio/rnnt/workspace.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
#include <vector>
55

66
#include <libtorchaudio/rnnt/options.h>
7-
8-
#include <c10/util/Logging.h>
7+
#include <torch/headeronly/util/Exception.h>
98

109
namespace torchaudio {
1110
namespace rnnt {
@@ -29,7 +28,7 @@ class DtypeWorkspace {
2928
~DtypeWorkspace() {}
3029

3130
static int ComputeSizeFromOptions(const Options& options) {
32-
TORCH_CHECK_NE(options.device_, UNDEFINED);
31+
STD_TORCH_CHECK(options.device_ != UNDEFINED);
3332
return ComputeSizeForDenominators(options) +
3433
ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) +
3534
ComputeSizeForBetas(options);
@@ -38,7 +37,7 @@ class DtypeWorkspace {
3837
void Free();
3938
void Reset(const Options& options, DTYPE* data, int size) {
4039
int needed_size = ComputeSizeFromOptions(options);
41-
TORCH_CHECK_LE(needed_size, size);
40+
STD_TORCH_CHECK(needed_size <= size);
4241
options_ = options;
4342
data_ = data;
4443
size_ = size;
@@ -100,7 +99,7 @@ class IntWorkspace {
10099

101100
void Reset(const Options& options, int* data, int size) {
102101
int needed_size = ComputeSizeFromOptions(options);
103-
TORCH_CHECK_LE(needed_size, size);
102+
STD_TORCH_CHECK(needed_size <= size);
104103
options_ = options;
105104
data_ = data;
106105
size_ = size;
@@ -111,11 +110,11 @@ class IntWorkspace {
111110
}
112111

113112
int* GetPointerToAlphaCounters() const {
114-
TORCH_CHECK_EQ(options_.device_, GPU);
113+
STD_TORCH_CHECK(options_.device_ == GPU);
115114
return data_;
116115
}
117116
int* GetPointerToBetaCounters() const {
118-
TORCH_CHECK_EQ(options_.device_, GPU);
117+
STD_TORCH_CHECK(options_.device_ == GPU);
119118
return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_);
120119
}
121120

src/libtorchaudio/utils.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
#include <ATen/DynamicLibrary.h>
21
#include <libtorchaudio/utils.h>
3-
#include <torch/csrc/stable/tensor.h>
42

53
#ifdef USE_CUDA
64
#include <cuda.h>

0 commit comments

Comments
 (0)