Skip to content

Commit 06776f8

Browse files
authored (author name missing from page extraction)
[STABLE ABI] Eliminate sizes, strides, mutable_data_ptr, const_data_ptr, new_zeros ops (#4146)
1 parent 85f4ce5 commit 06776f8

File tree

5 files changed

+24
-109
lines changed

5 files changed

+24
-109
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ std::tuple<Tensor, Tensor> compute(
208208
ScalarType::Long);
209209
const auto B = logProbs.size(0);
210210
const auto T = logProbs.size(1);
211-
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
211+
Tensor paths = torch::stable::empty({B, T}, targets.scalar_type());
212+
torch::stable::zero_(paths);
212213
THO_DISPATCH_V2(
213214
logProbs.scalar_type(),
214215
"forced_align_impl",

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ std::tuple<Tensor, Tensor> compute(
294294
auto B = logProbs.size(0);
295295
auto T = logProbs.size(1); // num frames
296296

297-
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torch::stable::DeviceType::CPU);
297+
Tensor paths = torch::stable::empty({B, T}, targets.scalar_type());
298+
torch::stable::zero_(paths);
298299

299300
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
300301
if (targets.scalar_type() == ScalarType::Long) {

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,10 @@ std::tuple<Tensor, Tensor> compute(
114114
// when stable ABI Tensor supports mutable_data_ptr templates.
115115
Workspace<float> workspace(
116116
/*options=*/options,
117-
/*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
117+
/*dtype_data=*/
118+
reinterpret_cast<float*>(float_workspace.mutable_data_ptr()),
118119
/*dtype_size=*/float_workspace.numel(),
119-
/*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
120+
/*int_data=*/reinterpret_cast<int*>(int_workspace.mutable_data_ptr()),
120121
/*int_size=*/int_workspace.numel());
121122

122123
THO_DISPATCH_V2(

src/libtorchaudio/stable/ops.h

Lines changed: 12 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -17,52 +17,15 @@
1717
#include <c10/cuda/CUDAException.h>
1818
#endif
1919

20-
using torch::stable::Tensor;
21-
2220
namespace torchaudio::stable {
2321

24-
using Layout = int32_t;
25-
26-
// TODO: When sizes and strides are implemented in torch::stable,
27-
// eliminate sizes and strides function below.
28-
inline std::vector<int64_t> sizes(const Tensor& t) {
29-
int64_t* ptr;
30-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(t.get(), &ptr));
31-
std::vector<int64_t> r(ptr, ptr + t.dim());
32-
return r;
33-
}
34-
35-
inline std::vector<int64_t> strides(const Tensor& t) {
36-
int64_t* ptr;
37-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(t.get(), &ptr));
38-
std::vector<int64_t> r(ptr, ptr + t.dim());
39-
return r;
40-
}
41-
42-
// TODO: When https://github.com/pytorch/pytorch/pull/161891 lands,
43-
// eliminate mutable_data_ptr and const_data_ptr templates.
44-
#define aoti_torch_get_mutable_data_ptr aoti_torch_get_data_ptr
45-
#define aoti_torch_get_const_data_ptr aoti_torch_get_data_ptr
46-
template <typename T>
47-
T* mutable_data_ptr(const Tensor& t) {
48-
void* data_ptr{};
49-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_mutable_data_ptr(t.get(), &data_ptr));
50-
return reinterpret_cast<T*>(data_ptr);
51-
}
52-
53-
template <typename T>
54-
const T* const_data_ptr(const Tensor& t) {
55-
const void* data_ptr{};
56-
TORCH_ERROR_CODE_CHECK(
57-
aoti_torch_get_const_data_ptr(t.get(), const_cast<void**>(&data_ptr)));
58-
return reinterpret_cast<const T*>(data_ptr);
59-
}
22+
using torch::stable::Tensor;
6023

61-
// TODO: When cpu is implemented in torch::stable, eliminate
62-
// cpu function below.
24+
// TODO: When cpu op is implemented in torch::stable, eliminate cpu
25+
// function below.
6326
inline Tensor cpu(const Tensor& self) {
64-
auto sizes_ = sizes(self);
65-
auto cpu_type = aoti_torch_device_type_cpu();
27+
auto sizes_ = self.sizes();
28+
int32_t cpu_type = static_cast<int32_t>(torch::stable::DeviceType::CPU);
6629
int32_t dtype;
6730
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
6831
int32_t layout;
@@ -83,10 +46,11 @@ inline Tensor cpu(const Tensor& self) {
8346
return result;
8447
}
8548

86-
// TODO:
49+
// TODO: When cuda op is implemented in torch::stable, eliminate cuda
50+
// function below.
8751
inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
88-
auto sizes_ = sizes(self);
89-
auto cuda_type = aoti_torch_device_type_cuda();
52+
auto sizes_ = self.sizes();
53+
int32_t cuda_type = static_cast<int32_t>(torch::stable::DeviceType::CUDA);
9054
int32_t dtype;
9155
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
9256
int32_t layout;
@@ -107,69 +71,16 @@ inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
10771
return result;
10872
}
10973

110-
// TODO: remove when torch::stable provides new_zeros
111-
inline Tensor new_zeros(
112-
const Tensor& self,
113-
std::vector<int64_t> size,
114-
std::optional<c10::ScalarType> dtype = std::nullopt,
115-
std::optional<Layout> layout = std::nullopt,
116-
std::optional<torch::stable::Device> device = std::nullopt,
117-
std::optional<bool> pin_memory = std::nullopt) {
118-
int32_t target_dtype{};
119-
if (dtype.has_value()) {
120-
target_dtype = torch::stable::detail::to<int32_t>(
121-
torch::stable::detail::from(dtype.value()));
122-
} else {
123-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
124-
}
125-
126-
Layout layout_;
127-
if (layout.has_value()) {
128-
layout_ = layout.value();
129-
} else {
130-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout_));
131-
}
132-
133-
int32_t device_type;
134-
torch::stable::DeviceIndex device_index = 0;
135-
if (device.has_value()) {
136-
auto device_ = device.value();
137-
device_type = static_cast<int32_t>(device_.type());
138-
device_index = device_.index();
139-
} else {
140-
TORCH_ERROR_CODE_CHECK(
141-
aoti_torch_get_device_type(self.get(), &device_type));
142-
TORCH_ERROR_CODE_CHECK(
143-
aoti_torch_get_device_index(self.get(), &device_index));
144-
}
145-
146-
// TODO: pin_memory
147-
148-
AtenTensorHandle ret0;
149-
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
150-
self.get(),
151-
size.data(),
152-
static_cast<int64_t>(size.size()),
153-
&target_dtype,
154-
&layout_,
155-
&device_type,
156-
device_index,
157-
nullptr, // pin_memory (nullptr for default)
158-
&ret0));
159-
160-
auto result = Tensor(ret0);
161-
torch::stable::zero_(result);
162-
return result;
163-
}
164-
16574
// An analog of item template function defined in
16675
// ATen/templates/TensorBody.h
16776
template <typename T>
16877
T item(const Tensor& self) {
16978
STD_TORCH_CHECK(
17079
self.numel() == 1, "item requires single element tensor input");
17180
if (self.is_cpu()) {
172-
return torchaudio::stable::const_data_ptr<T>(self)[0];
81+
// TODO: use `return self.const_data_ptr<T>()[0];` after torch
82+
// stable supports const_data_ptr templates.
83+
return reinterpret_cast<const T*>(self.const_data_ptr())[0];
17384
#ifdef USE_CUDA
17485
} else if (self.is_cuda()) {
17586
T value;

src/libtorchaudio/utils.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
// TODO: replace the include libtorchaudio/stable/ops.h with
66
// torch/stable/ops.h when torch::stable provides all required
7-
// features (torch::stable::item<T> or similar):
7+
// features (torch::stable::item<T> et al):
88
#include <libtorchaudio/stable/ops.h>
99

1010
namespace torchaudio {
@@ -25,7 +25,7 @@ using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
2525
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
2626
// after Tensor::accessor is supported in stable ABI
2727
template <typename T, size_t N>
28-
inline TensorAccessor<T, N> accessor(Tensor t) {
28+
inline TensorAccessor<T, N> accessor(torch::stable::Tensor t) {
2929
return TensorAccessor<T, N>(
3030
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
3131
}
@@ -42,7 +42,7 @@ using PackedTensorAccessor32 =
4242
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
4343
// after Tensor::accessor is supported in stable ABI
4444
template <typename T, size_t N>
45-
inline PackedTensorAccessor32<T, N> packed_accessor32(Tensor t) {
45+
inline PackedTensorAccessor32<T, N> packed_accessor32(torch::stable::Tensor t) {
4646
return PackedTensorAccessor32<T, N>(
4747
static_cast<typename PackedTensorAccessor32<T, N>::PtrType>(t.data_ptr()),
4848
t.sizes().data(),
@@ -58,7 +58,8 @@ using PackedTensorAccessorSizeT =
5858
size_t>;
5959

6060
template <typename T, size_t N>
61-
inline PackedTensorAccessorSizeT<T, N> packed_accessor_size_t(Tensor t) {
61+
inline PackedTensorAccessorSizeT<T, N> packed_accessor_size_t(
62+
torch::stable::Tensor t) {
6263
return PackedTensorAccessorSizeT<T, N>(
6364
static_cast<typename PackedTensorAccessorSizeT<T, N>::PtrType>(
6465
t.data_ptr()),

0 commit comments

Comments (0)