#include <c10/cuda/CUDAException.h>
#endif

-using torch::stable::Tensor;
-
namespace torchaudio::stable {

-using Layout = int32_t;
-
-// TODO: When sizes and strides are implemented in torch::stable,
-// eliminate sizes and strides function below.
-inline std::vector<int64_t> sizes(const Tensor& t) {
-  int64_t* ptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(t.get(), &ptr));
-  std::vector<int64_t> r(ptr, ptr + t.dim());
-  return r;
-}
-
-inline std::vector<int64_t> strides(const Tensor& t) {
-  int64_t* ptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(t.get(), &ptr));
-  std::vector<int64_t> r(ptr, ptr + t.dim());
-  return r;
-}
-
-// TODO: When https://github.com/pytorch/pytorch/pull/161891 lands,
-// eliminate mutable_data_ptr and const_data_ptr templates.
-#define aoti_torch_get_mutable_data_ptr aoti_torch_get_data_ptr
-#define aoti_torch_get_const_data_ptr aoti_torch_get_data_ptr
-template <typename T>
-T* mutable_data_ptr(const Tensor& t) {
-  void* data_ptr{};
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_mutable_data_ptr(t.get(), &data_ptr));
-  return reinterpret_cast<T*>(data_ptr);
-}
-
-template <typename T>
-const T* const_data_ptr(const Tensor& t) {
-  const void* data_ptr{};
-  TORCH_ERROR_CODE_CHECK(
-      aoti_torch_get_const_data_ptr(t.get(), const_cast<void**>(&data_ptr)));
-  return reinterpret_cast<const T*>(data_ptr);
-}
+using torch::stable::Tensor;

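For call sites, the replacements for the removed free functions are the
corresponding torch::stable::Tensor members, as the + lines below show. A
minimal migration sketch, with `t` a hypothetical Tensor; a typed
const_data_ptr<T>() member is still pending per the TODO above:

    auto s = t.sizes();  // was: torchaudio::stable::sizes(t)
    const float* p = reinterpret_cast<const float*>(t.const_data_ptr());
    // was: torchaudio::stable::const_data_ptr<float>(t)
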
-// TODO: When cpu is implemented in torch::stable, eliminate
-// cpu function below.
+// TODO: When cpu op is implemented in torch::stable, eliminate cpu
+// function below.
inline Tensor cpu(const Tensor& self) {
-  auto sizes_ = sizes(self);
-  auto cpu_type = aoti_torch_device_type_cpu();
+  auto sizes_ = self.sizes();
+  int32_t cpu_type = static_cast<int32_t>(torch::stable::DeviceType::CPU);
  int32_t dtype;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
  int32_t layout;
@@ -83,10 +46,11 @@ inline Tensor cpu(const Tensor& self) {
  return result;
}

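A hedged usage sketch for cpu() above (not part of this diff): it mirrors
Tensor::cpu(), returning a tensor with the same sizes and dtype backed by CPU
memory. `t` is a hypothetical Tensor on any device:

    Tensor host = torchaudio::stable::cpu(t);
    STD_TORCH_CHECK(host.is_cpu(), "cpu() should return a CPU tensor");
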
-// TODO:
+// TODO: When cuda op is implemented in torch::stable, eliminate cuda
+// function below.
inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
-  auto sizes_ = sizes(self);
-  auto cuda_type = aoti_torch_device_type_cuda();
+  auto sizes_ = self.sizes();
+  int32_t cuda_type = static_cast<int32_t>(torch::stable::DeviceType::CUDA);
  int32_t dtype;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
  int32_t layout;
@@ -107,69 +71,16 @@ inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
  return result;
}

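Similarly for cuda(), a hedged sketch (not part of this diff) of moving a
hypothetical tensor `t` to a device and back; cuda_index names the target
device, and this assumes CUDA device 0 is available:

    Tensor on_gpu = torchaudio::stable::cuda(t, /*cuda_index=*/0);
    STD_TORCH_CHECK(on_gpu.is_cuda(), "cuda() should return a CUDA tensor");
    Tensor round_trip = torchaudio::stable::cpu(on_gpu);
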
-// TODO: remove when torch::stable provides new_zeros
-inline Tensor new_zeros(
-    const Tensor& self,
-    std::vector<int64_t> size,
-    std::optional<c10::ScalarType> dtype = std::nullopt,
-    std::optional<Layout> layout = std::nullopt,
-    std::optional<torch::stable::Device> device = std::nullopt,
-    std::optional<bool> pin_memory = std::nullopt) {
-  int32_t target_dtype{};
-  if (dtype.has_value()) {
-    target_dtype = torch::stable::detail::to<int32_t>(
-        torch::stable::detail::from(dtype.value()));
-  } else {
-    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
-  }
-
-  Layout layout_;
-  if (layout.has_value()) {
-    layout_ = layout.value();
-  } else {
-    TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout_));
-  }
-
-  int32_t device_type;
-  torch::stable::DeviceIndex device_index = 0;
-  if (device.has_value()) {
-    auto device_ = device.value();
-    device_type = static_cast<int32_t>(device_.type());
-    device_index = device_.index();
-  } else {
-    TORCH_ERROR_CODE_CHECK(
-        aoti_torch_get_device_type(self.get(), &device_type));
-    TORCH_ERROR_CODE_CHECK(
-        aoti_torch_get_device_index(self.get(), &device_index));
-  }
-
-  // TODO: pin_memory
-
-  AtenTensorHandle ret0;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
-      self.get(),
-      size.data(),
-      static_cast<int64_t>(size.size()),
-      &target_dtype,
-      &layout_,
-      &device_type,
-      device_index,
-      nullptr, // pin_memory (nullptr for default)
-      &ret0));
-
-  auto result = Tensor(ret0);
-  torch::stable::zero_(result);
-  return result;
-}
-
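The removed new_zeros helper mirrored Tensor::new_zeros (new_empty followed by
zero_). A hedged sketch of what call sites presumably migrate to, assuming
torch::stable now provides a new_zeros of the shape the TODO anticipated; this
is not confirmed by the diff, and `batch`/`channels` are hypothetical sizes:

    Tensor z = torch::stable::new_zeros(self, {batch, channels});
    // dtype, layout, and device default to self's, as in the removed helper
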
// An analog of the item template function defined in
// ATen/templates/TensorBody.h
template <typename T>
T item(const Tensor& self) {
  STD_TORCH_CHECK(
      self.numel() == 1, "item requires single element tensor input");
  if (self.is_cpu()) {
-    return torchaudio::stable::const_data_ptr<T>(self)[0];
+    // TODO: use `return self.const_data_ptr<T>()[0];` after torch
+    // stable supports const_data_ptr templates.
+    return reinterpret_cast<const T*>(self.const_data_ptr())[0];
#ifdef USE_CUDA
  } else if (self.is_cuda()) {
    T value;
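A hedged usage sketch for item<T>() (not part of this diff): `scores` is a
hypothetical single-element tensor, e.g. the result of a reduction; anything
larger trips the STD_TORCH_CHECK above:

    float v = torchaudio::stable::item<float>(scores);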