1717#include < c10/cuda/CUDAException.h>
1818#endif
1919
20- using torch::stable::Tensor;
21-
2220namespace torchaudio ::stable {
2321
24- using Layout = int32_t ;
22+ using torch::stable::Tensor ;
2523
26- // TODO: When cpu is implemented in torch::stable, eliminate
27- // cpu function below.
24+ // TODO: When cpu op is implemented in torch::stable, eliminate cpu
25+ // function below.
2826inline Tensor cpu (const Tensor& self) {
2927 auto sizes_ = self.sizes ();
3028 int32_t cpu_type = static_cast <int32_t >(torch::stable::DeviceType::CPU);
@@ -48,7 +46,8 @@ inline Tensor cpu(const Tensor& self) {
4846 return result;
4947}
5048
51- // TODO:
49+ // TODO: When cuda op is implemented in torch::stable, eliminate cuda
50+ // function below.
5251inline Tensor cuda (const Tensor& self, int32_t cuda_index) {
5352 auto sizes_ = self.sizes ();
5453 int32_t cuda_type = static_cast <int32_t >(torch::stable::DeviceType::CUDA);
@@ -72,61 +71,6 @@ inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
7271 return result;
7372}
7473
75- // TODO: remove when torch::stable provides new_zeros
76- inline Tensor new_zeros (
77- const Tensor& self,
78- std::vector<int64_t > size,
79- std::optional<c10::ScalarType> dtype = std::nullopt ,
80- std::optional<Layout> layout = std::nullopt ,
81- std::optional<torch::stable::Device> device = std::nullopt ,
82- std::optional<bool > pin_memory = std::nullopt ) {
83- int32_t target_dtype{};
84- if (dtype.has_value ()) {
85- target_dtype = torch::stable::detail::to<int32_t >(
86- torch::stable::detail::from (dtype.value ()));
87- } else {
88- TORCH_ERROR_CODE_CHECK (aoti_torch_get_dtype (self.get (), &target_dtype));
89- }
90-
91- Layout layout_;
92- if (layout.has_value ()) {
93- layout_ = layout.value ();
94- } else {
95- TORCH_ERROR_CODE_CHECK (aoti_torch_get_layout (self.get (), &layout_));
96- }
97-
98- int32_t device_type;
99- torch::stable::DeviceIndex device_index = 0 ;
100- if (device.has_value ()) {
101- auto device_ = device.value ();
102- device_type = static_cast <int32_t >(device_.type ());
103- device_index = device_.index ();
104- } else {
105- TORCH_ERROR_CODE_CHECK (
106- aoti_torch_get_device_type (self.get (), &device_type));
107- TORCH_ERROR_CODE_CHECK (
108- aoti_torch_get_device_index (self.get (), &device_index));
109- }
110-
111- // TODO: pin_memory
112-
113- AtenTensorHandle ret0;
114- TORCH_ERROR_CODE_CHECK (aoti_torch_aten_new_empty (
115- self.get (),
116- size.data (),
117- static_cast <int64_t >(size.size ()),
118- &target_dtype,
119- &layout_,
120- &device_type,
121- device_index,
122- nullptr , // pin_memory (nullptr for default)
123- &ret0));
124-
125- auto result = Tensor (ret0);
126- torch::stable::zero_ (result);
127- return result;
128- }
129-
13074// An analog of item template function defined in
13175// ATen/templates/TensorBody.h
13276template <typename T>
0 commit comments