
Commit 8b8fa6c

Eliminate sizes, strides, mutable_data_ptr, const_data_ptr ops
1 parent 85f4ce5

2 files changed: +10 −42 lines

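
For orientation before the file diffs, here is a minimal before/after sketch of the helper removal. It is not code from the repository: the function name numel_from_sizes and the include path are assumptions, while the member calls mirror what the hunks below use.

// Hypothetical illustration only; numel_from_sizes is made up for this sketch.
#include <cstdint>
#include <torch/csrc/stable/tensor.h> // assumed header for torch::stable::Tensor

inline int64_t numel_from_sizes(const torch::stable::Tensor& t) {
  // Before this commit, torchaudio::stable::sizes(t) wrapped
  // aoti_torch_get_sizes(); the member Tensor::sizes() is now used directly.
  int64_t n = 1;
  for (int64_t d : t.sizes()) {
    n *= d;
  }
  return n;
}

The device-type constants move the same way: the shim calls aoti_torch_device_type_cpu() and aoti_torch_device_type_cuda() give way to static_cast<int32_t>(torch::stable::DeviceType::CPU) and static_cast<int32_t>(torch::stable::DeviceType::CUDA) in the cpu() and cuda() helpers, as the ops.h hunks show.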

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 3 additions & 2 deletions
@@ -114,9 +114,10 @@ std::tuple<Tensor, Tensor> compute(
   // when stable ABI Tensor supports mutable_data_ptr templates.
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
+      /*dtype_data=*/
+      reinterpret_cast<float*>(float_workspace.mutable_data_ptr()),
       /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
+      /*int_data=*/reinterpret_cast<int*>(int_workspace.mutable_data_ptr()),
       /*int_size=*/int_workspace.numel());
 
   THO_DISPATCH_V2(
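
The reinterpret_cast in the hunk above is needed because the stable Tensor's mutable_data_ptr() returns an untyped void*; the deleted torchaudio::stable::mutable_data_ptr<T> template used to hide that cast. Below is a hedged sketch of the call-site pattern, with typed_mutable_data as a made-up helper name and the include path assumed.

#include <torch/csrc/stable/tensor.h> // assumed header for torch::stable::Tensor

// Hypothetical helper for this sketch only; compute.cpp simply inlines the cast.
template <typename T>
T* typed_mutable_data(torch::stable::Tensor& t) {
  // Current form: mutable_data_ptr() is untyped, so the element type is
  // supplied via reinterpret_cast at the call site.
  return reinterpret_cast<T*>(t.mutable_data_ptr());
  // Anticipated form once stable ABI Tensor supports mutable_data_ptr
  // templates, per the TODO kept in the hunk above (not available here):
  //   return t.mutable_data_ptr<T>();
}

In that hypothetical form, the Workspace arguments above would read typed_mutable_data<float>(float_workspace) and typed_mutable_data<int>(int_workspace).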

src/libtorchaudio/stable/ops.h

Lines changed: 7 additions & 40 deletions
@@ -23,46 +23,11 @@ namespace torchaudio::stable {
 
 using Layout = int32_t;
 
-// TODO: When sizes and strides are implemented in torch::stable,
-// eliminate sizes and strides function below.
-inline std::vector<int64_t> sizes(const Tensor& t) {
-  int64_t* ptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(t.get(), &ptr));
-  std::vector<int64_t> r(ptr, ptr + t.dim());
-  return r;
-}
-
-inline std::vector<int64_t> strides(const Tensor& t) {
-  int64_t* ptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(t.get(), &ptr));
-  std::vector<int64_t> r(ptr, ptr + t.dim());
-  return r;
-}
-
-// TODO: When https://github.com/pytorch/pytorch/pull/161891 lands,
-// eliminate mutable_data_ptr and const_data_ptr templates.
-#define aoti_torch_get_mutable_data_ptr aoti_torch_get_data_ptr
-#define aoti_torch_get_const_data_ptr aoti_torch_get_data_ptr
-template <typename T>
-T* mutable_data_ptr(const Tensor& t) {
-  void* data_ptr{};
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_mutable_data_ptr(t.get(), &data_ptr));
-  return reinterpret_cast<T*>(data_ptr);
-}
-
-template <typename T>
-const T* const_data_ptr(const Tensor& t) {
-  const void* data_ptr{};
-  TORCH_ERROR_CODE_CHECK(
-      aoti_torch_get_const_data_ptr(t.get(), const_cast<void**>(&data_ptr)));
-  return reinterpret_cast<const T*>(data_ptr);
-}
-
 // TODO: When cpu is implemented in torch::stable, eliminate
 // cpu function below.
 inline Tensor cpu(const Tensor& self) {
-  auto sizes_ = sizes(self);
-  auto cpu_type = aoti_torch_device_type_cpu();
+  auto sizes_ = self.sizes();
+  int32_t cpu_type = static_cast<int32_t>(torch::stable::DeviceType::CPU);
   int32_t dtype;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
   int32_t layout;

@@ -85,8 +50,8 @@ inline Tensor cpu(const Tensor& self) {
 
 // TODO:
 inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
-  auto sizes_ = sizes(self);
-  auto cuda_type = aoti_torch_device_type_cuda();
+  auto sizes_ = self.sizes();
+  int32_t cuda_type = static_cast<int32_t>(torch::stable::DeviceType::CUDA);
   int32_t dtype;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
   int32_t layout;

@@ -169,7 +134,9 @@ T item(const Tensor& self) {
   STD_TORCH_CHECK(
       self.numel() == 1, "item requires single element tensor input");
   if (self.is_cpu()) {
-    return torchaudio::stable::const_data_ptr<T>(self)[0];
+    // TODO: use `return self.const_data_ptr<T>()[0];` after torch
+    // stable supports const_data_ptr templates.
+    return reinterpret_cast<const T*>(self.const_data_ptr())[0];
 #ifdef USE_CUDA
   } else if (self.is_cuda()) {
     T value;
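
Reads follow the same pattern: const_data_ptr() returns an untyped const void* in the stable ABI, so the updated CPU branch of item<T> casts at the call site, and the new TODO points at the templated const_data_ptr<T>() as the eventual replacement. A minimal sketch under those assumptions (read_scalar is a made-up name, the include path is assumed):

#include <torch/csrc/stable/tensor.h> // assumed header for torch::stable::Tensor

// Hypothetical stand-in for the CPU branch of item<T>() after this change.
template <typename T>
T read_scalar(const torch::stable::Tensor& self) {
  // `self` is assumed to be a one-element CPU tensor, as item() verifies.
  return reinterpret_cast<const T*>(self.const_data_ptr())[0];
}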
