Skip to content

Commit b68da9a

Browse files
committed
Eliminate new_zeros
1 parent 8b8fa6c commit b68da9a

File tree

4 files changed

+14
-67
lines changed

4 files changed

+14
-67
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ std::tuple<Tensor, Tensor> compute(
208208
ScalarType::Long);
209209
const auto B = logProbs.size(0);
210210
const auto T = logProbs.size(1);
211-
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
211+
Tensor paths = torch::stable::empty({B, T}, targets.scalar_type());
212+
torch::stable::zero_(paths);
212213
THO_DISPATCH_V2(
213214
logProbs.scalar_type(),
214215
"forced_align_impl",

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ std::tuple<Tensor, Tensor> compute(
294294
auto B = logProbs.size(0);
295295
auto T = logProbs.size(1); // num frames
296296

297-
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torch::stable::DeviceType::CPU);
297+
Tensor paths = torch::stable::empty({B, T}, targets.scalar_type());
298+
torch::stable::zero_(paths);
298299

299300
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
300301
if (targets.scalar_type() == ScalarType::Long) {

src/libtorchaudio/stable/ops.h

Lines changed: 5 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
#include <c10/cuda/CUDAException.h>
1818
#endif
1919

20-
using torch::stable::Tensor;
21-
2220
namespace torchaudio::stable {
2321

24-
using Layout = int32_t;
22+
using torch::stable::Tensor;
2523

26-
// TODO: When cpu is implemented in torch::stable, eliminate
27-
// cpu function below.
24+
// TODO: When cpu op is implemented in torch::stable, eliminate cpu
25+
// function below.
2826
inline Tensor cpu(const Tensor& self) {
2927
auto sizes_ = self.sizes();
3028
int32_t cpu_type = static_cast<int32_t>(torch::stable::DeviceType::CPU);
@@ -48,7 +46,8 @@ inline Tensor cpu(const Tensor& self) {
4846
return result;
4947
}
5048

51-
// TODO:
49+
// TODO: When cuda op is implemented in torch::stable, eliminate cuda
50+
// function below.
5251
inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
5352
auto sizes_ = self.sizes();
5453
int32_t cuda_type = static_cast<int32_t>(torch::stable::DeviceType::CUDA);
@@ -72,61 +71,6 @@ inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
7271
return result;
7372
}
7473

75-
// TODO: remove when torch::stable provides new_zeros
76-
inline Tensor new_zeros(
77-
const Tensor& self,
78-
std::vector<int64_t> size,
79-
std::optional<c10::ScalarType> dtype = std::nullopt,
80-
std::optional<Layout> layout = std::nullopt,
81-
std::optional<torch::stable::Device> device = std::nullopt,
82-
std::optional<bool> pin_memory = std::nullopt) {
83-
int32_t target_dtype{};
84-
if (dtype.has_value()) {
85-
target_dtype = torch::stable::detail::to<int32_t>(
86-
torch::stable::detail::from(dtype.value()));
87-
} else {
88-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
89-
}
90-
91-
Layout layout_;
92-
if (layout.has_value()) {
93-
layout_ = layout.value();
94-
} else {
95-
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout_));
96-
}
97-
98-
int32_t device_type;
99-
torch::stable::DeviceIndex device_index = 0;
100-
if (device.has_value()) {
101-
auto device_ = device.value();
102-
device_type = static_cast<int32_t>(device_.type());
103-
device_index = device_.index();
104-
} else {
105-
TORCH_ERROR_CODE_CHECK(
106-
aoti_torch_get_device_type(self.get(), &device_type));
107-
TORCH_ERROR_CODE_CHECK(
108-
aoti_torch_get_device_index(self.get(), &device_index));
109-
}
110-
111-
// TODO: pin_memory
112-
113-
AtenTensorHandle ret0;
114-
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
115-
self.get(),
116-
size.data(),
117-
static_cast<int64_t>(size.size()),
118-
&target_dtype,
119-
&layout_,
120-
&device_type,
121-
device_index,
122-
nullptr, // pin_memory (nullptr for default)
123-
&ret0));
124-
125-
auto result = Tensor(ret0);
126-
torch::stable::zero_(result);
127-
return result;
128-
}
129-
13074
// An analog of item template function defined in
13175
// ATen/templates/TensorBody.h
13276
template <typename T>

src/libtorchaudio/utils.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
// TODO: replace the include libtorchaudio/stable/ops.h with
66
// torch/stable/ops.h when torch::stable provides all required
7-
// features (torch::stable::item<T> or similar):
7+
// features (torch::stable::item<T> et al):
88
#include <libtorchaudio/stable/ops.h>
99

1010
namespace torchaudio {
@@ -25,7 +25,7 @@ using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
2525
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
2626
// after Tensor::accessor is supported in stable ABI
2727
template <typename T, size_t N>
28-
inline TensorAccessor<T, N> accessor(Tensor t) {
28+
inline TensorAccessor<T, N> accessor(torch::stable::Tensor t) {
2929
return TensorAccessor<T, N>(
3030
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
3131
}
@@ -42,7 +42,7 @@ using PackedTensorAccessor32 =
4242
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
4343
// after Tensor::accessor is supported in stable ABI
4444
template <typename T, size_t N>
45-
inline PackedTensorAccessor32<T, N> packed_accessor32(Tensor t) {
45+
inline PackedTensorAccessor32<T, N> packed_accessor32(torch::stable::Tensor t) {
4646
return PackedTensorAccessor32<T, N>(
4747
static_cast<typename PackedTensorAccessor32<T, N>::PtrType>(t.data_ptr()),
4848
t.sizes().data(),
@@ -58,7 +58,8 @@ using PackedTensorAccessorSizeT =
5858
size_t>;
5959

6060
template <typename T, size_t N>
61-
inline PackedTensorAccessorSizeT<T, N> packed_accessor_size_t(Tensor t) {
61+
inline PackedTensorAccessorSizeT<T, N> packed_accessor_size_t(
62+
torch::stable::Tensor t) {
6263
return PackedTensorAccessorSizeT<T, N>(
6364
static_cast<typename PackedTensorAccessorSizeT<T, N>::PtrType>(
6465
t.data_ptr()),

0 commit comments

Comments (0)