@@ -23,46 +23,11 @@ namespace torchaudio::stable {
2323
2424using Layout = int32_t ;
2525
26- // TODO: When sizes and strides are implemented in torch::stable,
27- // eliminate sizes and strides function below.
28- inline std::vector<int64_t > sizes (const Tensor& t) {
29- int64_t * ptr;
30- TORCH_ERROR_CODE_CHECK (aoti_torch_get_sizes (t.get (), &ptr));
31- std::vector<int64_t > r (ptr, ptr + t.dim ());
32- return r;
33- }
34-
35- inline std::vector<int64_t > strides (const Tensor& t) {
36- int64_t * ptr;
37- TORCH_ERROR_CODE_CHECK (aoti_torch_get_strides (t.get (), &ptr));
38- std::vector<int64_t > r (ptr, ptr + t.dim ());
39- return r;
40- }
41-
42- // TODO: When https://github.com/pytorch/pytorch/pull/161891 lands,
43- // eliminate mutable_data_ptr and const_data_ptr templates.
44- #define aoti_torch_get_mutable_data_ptr aoti_torch_get_data_ptr
45- #define aoti_torch_get_const_data_ptr aoti_torch_get_data_ptr
46- template <typename T>
47- T* mutable_data_ptr (const Tensor& t) {
48- void * data_ptr{};
49- TORCH_ERROR_CODE_CHECK (aoti_torch_get_mutable_data_ptr (t.get (), &data_ptr));
50- return reinterpret_cast <T*>(data_ptr);
51- }
52-
53- template <typename T>
54- const T* const_data_ptr (const Tensor& t) {
55- const void * data_ptr{};
56- TORCH_ERROR_CODE_CHECK (
57- aoti_torch_get_const_data_ptr (t.get (), const_cast <void **>(&data_ptr)));
58- return reinterpret_cast <const T*>(data_ptr);
59- }
60-
6126// TODO: When cpu is implemented in torch::stable, eliminate
6227// cpu function below.
6328inline Tensor cpu (const Tensor& self) {
64- auto sizes_ = sizes (self );
65- auto cpu_type = aoti_torch_device_type_cpu ( );
29+ auto sizes_ = self. sizes ();
30+ int32_t cpu_type = static_cast < int32_t >(torch::stable::DeviceType::CPU );
6631 int32_t dtype;
6732 TORCH_ERROR_CODE_CHECK (aoti_torch_get_dtype (self.get (), &dtype));
6833 int32_t layout;
@@ -85,8 +50,8 @@ inline Tensor cpu(const Tensor& self) {
8550
8651// TODO:
8752inline Tensor cuda (const Tensor& self, int32_t cuda_index) {
88- auto sizes_ = sizes (self );
89- auto cuda_type = aoti_torch_device_type_cuda ( );
53+ auto sizes_ = self. sizes ();
54+ int32_t cuda_type = static_cast < int32_t >(torch::stable::DeviceType::CUDA );
9055 int32_t dtype;
9156 TORCH_ERROR_CODE_CHECK (aoti_torch_get_dtype (self.get (), &dtype));
9257 int32_t layout;
@@ -169,7 +134,9 @@ T item(const Tensor& self) {
169134 STD_TORCH_CHECK (
170135 self.numel () == 1 , " item requires single element tensor input" );
171136 if (self.is_cpu ()) {
172- return torchaudio::stable::const_data_ptr<T>(self)[0 ];
137+ // TODO: use `return self.const_data_ptr<T>()[0];` after torch
138+ // stable supports const_data_ptr templates.
139+ return reinterpret_cast <const T*>(self.const_data_ptr ())[0 ];
173140#ifdef USE_CUDA
174141 } else if (self.is_cuda ()) {
175142 T value;
0 commit comments